Example No. 1
def train(model, data, args):
    """
    Training a CapsuleNet

    Args:
        model: the CapsuleNet model
        data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
        args: arguments

    Returns:
        The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv', separator=',', append=True)
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size, histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                           save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. With shift_fraction=0., this is equivalent to no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        initial_epoch=args.initial_epoch,
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
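Note: `margin_loss` is imported from the surrounding CapsNet project and never shown in these examples. A minimal sketch, assuming the formulation from the CapsNet paper (m+ = 0.9, m- = 0.1, lambda = 0.5) and a legacy Keras backend:

from keras import backend as K

def margin_loss(y_true, y_pred):
    # L_c = T_c * max(0, m+ - ||v_c||)^2 + lambda * (1 - T_c) * max(0, ||v_c|| - m-)^2
    L = y_true * K.square(K.maximum(0., 0.9 - y_pred)) + \
        0.5 * (1 - y_true) * K.square(K.maximum(0., y_pred - 0.1))
    return K.mean(K.sum(L, axis=1))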
Example No. 2
def train(model, train_gen, val_gen, args):
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    model.fit_generator(generator=train_gen,
                        epochs=args.epochs,
                        validation_data=val_gen,
                        callbacks=[log, tb, checkpoint, lr_decay],
                        use_multiprocessing=True,
                        workers=6,
                        verbose=1)

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)

    return model
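These train() variants all read the same handful of attributes off `args`. A hypothetical argparse setup that would satisfy them (attribute names taken from the calls above; defaults are illustrative only):

import argparse

parser = argparse.ArgumentParser(description='CapsNet training')
parser.add_argument('--epochs', default=50, type=int)
parser.add_argument('--batch_size', default=100, type=int)
parser.add_argument('--lr', default=0.001, type=float, help='initial learning rate')
parser.add_argument('--lr_decay', default=0.9, type=float, help='per-epoch multiplicative decay')
parser.add_argument('--lam_recon', default=0.392, type=float, help='weight of the reconstruction loss')
parser.add_argument('--shift_fraction', default=0.1, type=float)
parser.add_argument('--debug', action='store_true', help='enable TensorBoard histograms')
parser.add_argument('--save_dir', default='./result')
args = parser.parse_args()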
Example No. 3
def train(model, data, args):
    """
    Training a 3-level DCNet
    :param model: the 3-level DCNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """

    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data
    row = x_train.shape[1]
    col = x_train.shape[2]
    channel = x_train.shape[3]

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs', histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5', monitor='val_capsnet_acc',
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

    # compile the model
    # Note the four separate margin losses (one per prediction level) plus the reconstruction loss
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, margin_loss, margin_loss, margin_loss, 'mse'],
                  loss_weights=[1., 1., 1., 1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    #model.load_weights('result/weights.h5')

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, y_train, y_train, y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, y_test, y_test, y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Training with data augmentation
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, y_batch, y_batch, y_batch, x_batch[:,:,:,0:1]])

    # Training with data augmentation. With shift_fraction=0., this is equivalent to no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, y_test, y_test, y_test, x_test[:,:,:,0:1]]],
                        callbacks=[log, tb, checkpoint, lr_decay])

    # Save model weights
    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)

    return model
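Example No. 4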
def train_model(model, args):
    print('Loading train data!')
    images_train, images_mask_train = load_data()

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')

    # To view TensorBoard:
    # method 1: python -m tensorboard.main --logdir=./
    # method 2: tensorboard --logdir=./
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)

    checkpoint = callbacks.ModelCheckpoint(
        args.save_dir + '/multi-trained_model.h5',
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True,
        verbose=1,
        mode='min',
    )

    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (0.99**epoch))

    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=args.patience,
        verbose=0,
        mode='min',
    )
    model = ParallelModel(model, args.gpus)

    # Resume training from a saved checkpoint:
    # model = keras.models.load_model(args.save_dir + '/trained_model_old.h5',
    #                                 custom_objects={'bce_dice_loss': bce_dice_loss, 'mean_iou': mean_iou})

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=bce_dice_loss,
                  metrics=["accuracy", mean_iou])

    # Fitting model
    model.fit(images_train,
              images_mask_train,
              batch_size=args.batch_size,
              epochs=args.epochs,
              verbose=1,
              shuffle=True,
              validation_split=0.2,
              callbacks=[log, tb, lr_decay, checkpoint, early_stopping])

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 5
def train(save_dir, batch_size, lr, shift_fraction, epochs, model, data, running_time):
    
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data
    class_weights_array = class_weight.compute_class_weight(
        'balanced',  # or None for uniform weights
        np.unique(np.argmax(y_train, axis=1)),
        np.argmax(y_train, axis=1))

    class_weights = {0: class_weights_array[1], 1: class_weights_array[0]}
    # callbacks
    
    log_dir = save_dir + '\\tensorboard-logs-dd' + '\\' + running_time
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)
    log = callbacks.CSVLogger(save_dir + '\\log-dd.csv')
    tb = callbacks.TensorBoard(log_dir=log_dir,
                               batch_size=batch_size)
    checkpoint = callbacks.ModelCheckpoint(save_dir + '\\weights-dd-{epoch:02d}.h5', monitor='val_acc',
                                           save_best_only=True, save_weights_only=True, verbose=1)

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=lr),
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])

    # Begin: Training with data augmentation---------------------------------------------------------------------#
    
    def train_generator(x, y, batch_size, savedir, shift_fraction=0.):
        if not os.path.exists(savedir):
            os.makedirs(savedir)
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           samplewise_std_normalization = False,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size, shuffle=False)
        while 1:
            x_batch, y_batch = generator.next()
            yield (x_batch, y_batch)

    # NOTE: train_generator above is defined but unused; training below runs on the raw arrays with class weights.
    print(class_weights)

    model.fit(x_train, y_train, batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              callbacks=[log, tb, checkpoint],
              class_weight= class_weights,
              shuffle=True)

    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(os.path.join(save_dir, 'trained_model_dd_toxo.h5'))
    print('Trained model saved to \'%s\\trained_model_dd_toxo.h5\'' % save_dir)

    
    plot_log(os.path.join(save_dir, 'log-dd.csv'), show=True)

    return model
Example No. 6
def train(model, args):
    # output callback func
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=False,
                                           save_weights_only=False,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    # x = np.load('sNORB_img_train_azimuths30_32_34_0_2_4.npy')
    # y = np.load('sNORB_lbl_train_azimuths30_32_34_0_2_4.npy')
    x = np.load('luna-images-merged08.npy')
    y = np.load('luna-labels-merged08.npy')

    x = np.array([resizeIMG(i) for i in np.array(x)])

    # pos_cnt = 0
    # for lab in y:
    #     if(np.argmax(lab) == 1):
    #         pos_cnt += 1
    #     # print(str(lab))
    # print("Positive: {0}, negative {1}".format(str(pos_cnt), str(len(y)-pos_cnt)))
    (train_images, val_images, train_labels,
     val_labels) = train_test_split(x,
                                    y,
                                    train_size=0.9,
                                    test_size=0.1,
                                    random_state=45)
    print("Using {0} images for training and {1} for validation.".format(
        str(len(train_images)), str(len(val_images))))

    gen_b = batch_generator(train_images, train_labels)
    gen_v = validation_generator(val_images, val_labels)
    model.fit_generator(generator=gen_b,
                        steps_per_epoch=conf.STEPS_PER_EPOCH,
                        epochs=args.epochs,
                        validation_data=gen_v,
                        validation_steps=conf.VALIDATION_STEPS,
                        callbacks=[log, tb, checkpoint, lr_decay])

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=False)

    return model
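`resizeIMG`, `batch_generator`, `validation_generator`, and `conf` are project-specific helpers that are not shown. A minimal sketch of generators compatible with this two-input/two-output fit_generator call, under that assumption:

import numpy as np

def batch_generator(images, labels, batch_size=16):
    # Hypothetical stand-in: yields ([x, y], [y, x]) batches, since the label
    # is also a decoder input and the image a reconstruction target.
    n = len(images)
    while True:
        idx = np.random.randint(0, n, size=batch_size)
        yield ([images[idx], labels[idx]], [labels[idx], images[idx]])

def validation_generator(images, labels, batch_size=16):
    return batch_generator(images, labels, batch_size)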
Example No. 7
def train(model, data, args, dirs):
    """
    The function which defines the training loop of the model

    Parameters
    ----------
    model : `keras.models.Model`
        The structure of the model which is to be trained
    data : `tuple`
        The training and validation data
    args : `dict`
        The argument dictionary which defines other parameters at training time
    dirs : `string`
        Filepath to store the logs
    """

    # Extract the data
    (x_train, y_train), (x_val, y_val) = data

    # callbacks
    log = callbacks.CSVLogger(dirs + '/log.csv')

    tb = callbacks.TensorBoard(log_dir=dirs + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))

    checkpoint = callbacks.ModelCheckpoint(dirs + '/model.h5',
                                           monitor='val_acc',
                                           save_best_only=True,
                                           save_weights_only=False,
                                           verbose=1)

    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss='binary_crossentropy',
                  metrics=['acc'])

    # Training without data augmentation:
    model.fit(x_train,
              y_train,
              batch_size=args.batch_size,
              epochs=args.epochs,
              verbose=1,
              validation_data=(x_val, y_val),
              callbacks=[log, tb, checkpoint, lr_decay])
    # Optionally append roc_auc_callback((x_train, y_train), (x_val, y_val)) to the callbacks list.

    # Save the trained model
    model.save(dirs + '/trained_model.h5')

    # Plot the training results
    plot_log(dirs, show=False)

    return model
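Nearly every example finishes with `plot_log` from the project's utils module (signatures vary: some pass a file path, some a directory). A plausible minimal version of the path-based variant, reading the columns that CSVLogger writes:

import pandas as pd
import matplotlib.pyplot as plt

def plot_log(filename, show=True):
    # Sketch: plot the loss and accuracy columns from the CSVLogger file.
    data = pd.read_csv(filename)
    fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))
    for col in data.columns:
        if 'loss' in col:
            ax_loss.plot(data['epoch'], data[col], label=col)
        elif 'acc' in col:
            ax_acc.plot(data['epoch'], data[col], label=col)
    ax_loss.set_title('loss')
    ax_acc.set_title('accuracy')
    ax_loss.legend()
    ax_acc.legend()
    if show:
        plt.show()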
Example No. 8
def train(model, data, args):
    """
    训练胶囊网络
    ## model: 胶囊网络模型
    ## data: 包含训练和测试数据,形如((x_train, y_train), (x_test, y_test))
    ## args: 输入的训练参数
    :return: 训练好的模型
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(
        optimizer=optimizers.Adam(lr=args.lr),
        loss=[margin_loss, 'mse'],  # two losses: the capsule (margin) loss and the decoder (reconstruction) loss
        loss_weights=[1., args.lam_recon],  # relative weights of the two losses
        metrics={'capsnet': 'accuracy'})

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    model.fit_generator(
        generator=train_generator(x_train, y_train, args.batch_size,
                                  args.shift_fraction),
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#
    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 9
def train(model, data, args):
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.hdf5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(
        optimizer=optimizers.Adam(lr=args.lr),
        # We scale down the reconstruction loss (by 0.0005) so that it does
        # not dominate the margin loss during training.
        loss=[margin_loss, reconstruction_loss],
        loss_weights=[1., args.scale_reconstruction_loss],
        metrics={'capsnet': 'accuracy'})

    # Generator with data augmentation as used in [1]
    def train_generator_with_augmentation(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(
            width_shift_range=shift_fraction,
            height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    generator = train_generator_with_augmentation(x_train, y_train,
                                                  args.batch_size,
                                                  args.shift_fraction)
    model.fit_generator(
        generator=generator,
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        # Note: for the decoder the input is the label and the output the image.
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, checkpoint, lr_decay])

    model.save_weights(args.save_dir + '/trained_model.hdf5')
    print('Trained model saved to \'%s/trained_model.hdf5\'' % args.save_dir)

    utils.plot_log(args.save_dir + '/log.csv', show=True)

    return model
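`reconstruction_loss` is defined elsewhere in this project; given that the other examples pass 'mse' in the same slot, it is presumably a plain mean squared error:

from keras import backend as K

def reconstruction_loss(y_true, y_pred):
    # Assumed equivalent of the built-in 'mse' used by the other examples.
    return K.mean(K.square(y_pred - y_true))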
Example No. 10
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size, histogram_freq=args.debug)
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5',
                                           save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: 0.001 * np.exp(-epoch / 10.))

    # compile the model
    model.compile(optimizer='adam',
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'out_caps': 'accuracy'})

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint])
    """

    # -----------------------------------Begin: Training with data augmentation -----------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. With shift_fraction=0., this is equivalent to no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, checkpoint, lr_decay])
    # -----------------------------------End: Training with data augmentation -----------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 11
def train(model, args):
    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    # Directory-based data generators ----------------------------------------------------------------------------#
    train_datagen = ImageDataGenerator(validation_split=val_split,
                                       rescale=1. / 255)

    train_generator = train_datagen.flow_from_directory(
        train_data_dir,
        shuffle=False,
        target_size=(img_height, img_width),
        batch_size=args.batch_size,
        class_mode='categorical',
        subset='training')

    validation_generator = train_datagen.flow_from_directory(
        train_data_dir,
        shuffle=False,
        target_size=(img_height, img_width),
        batch_size=args.batch_size,
        class_mode='categorical',
        subset='validation')

    k = TestCallback(validation_generator)

    model.fit(train_generator,
              steps_per_epoch=(totalpics * (1 - val_split)) // args.batch_size,
              epochs=args.epochs,
              callbacks=[log, checkpoint, lr_decay, k],
              validation_data=validation_generator,
              validation_steps=(totalpics * (val_split)) // args.batch_size)
    # ------------------------------------------------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
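This example depends on module-level globals (`train_data_dir`, `img_height`, `img_width`, `val_split`, `totalpics`) and a `TestCallback` class that are not shown. A hypothetical minimal setup consistent with how they are used:

from keras import callbacks

train_data_dir = 'data/train'   # placeholder path
img_height, img_width = 64, 64  # placeholder target size
val_split = 0.2
totalpics = 10000               # total number of images under train_data_dir

class TestCallback(callbacks.Callback):
    # Sketch: re-evaluate the validation generator at the end of each epoch.
    def __init__(self, generator):
        super(TestCallback, self).__init__()
        self.generator = generator

    def on_epoch_end(self, epoch, logs=None):
        results = self.model.evaluate_generator(self.generator)
        print('\nEpoch %d validation check: %s' % (epoch + 1, results))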
Example No. 12
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # wandb.init(name=RUN_NAME, project="6_class", notes=COMMENTS)

    # unpacking the data
    ([x_train_channelb, x_train_channelc, x_train_channeld],
     y_train), ([x_test_channelb, x_test_channelc,
                 x_test_channeld], y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss],
                  loss_weights=[1.],
                  metrics={'capsnet': 'accuracy'})

    model.fit(
        [x_train_channelb, x_train_channelc, x_train_channeld],
        y_train,
        batch_size=args.batch_size,
        epochs=args.epochs,
        validation_data=[[x_test_channelb, x_test_channelc, x_test_channeld],
                         y_test],
        callbacks=[log, tb, checkpoint, lr_decay])
    # model.fit([x_train_channelb, x_train_channelc, x_train_channeld], y_train, batch_size=args.batch_size, epochs=args.epochs,
    #           validation_data=[[x_test_channelb, x_test_channelc, x_test_channeld], y_test], callbacks=[log, tb, checkpoint, lr_decay, WandbCallback()])

    model.save_weights(args.save_dir + '/trained_model.h5')
    # model.save(os.path.join(wandb.run.dir, "model.h5"))
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 13
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (0.9**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon])
    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(
            width_shift_range=shift_fraction,
            height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. With shift_fraction=0., this is equivalent to no augmentation.
    model.fit_generator(
        generator=train_generator(x_train, y_train, args.batch_size,
                                  args.shift_fraction),
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 14
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data
    
    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size, histogram_freq=args.debug)
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.h5',
                                           save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (0.9 ** epoch))

    # compile the model
    
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'out_caps': 'accuracy'})
    
    
    # Training without data augmentation:
    if len(x_test) == 0:
        model.fit([x_train, y_train], [y_train, x_train], 
              batch_size=args.batch_size, 
              epochs=args.epochs,
              validation_split=0.2, 
              callbacks=[log, tb, checkpoint, lr_decay])
    else:
        model.fit([x_train, y_train], [y_train, x_train], 
              batch_size=args.batch_size, 
              epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], 
              callbacks=[log, tb, checkpoint, lr_decay])
    
    weights_file = os.path.join(args.save_dir, args.save_weight)
    model.save_weights(weights_file)
    print('Trained model saved to {}'.format(weights_file))

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 15
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size, histogram_freq=args.debug)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (0.9 ** epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon])

    """
    # Training without data augmentation:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. With shift_fraction=0., this is equivalent to no augmentation.
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size, args.shift_fraction),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[log, tb, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 16
def train(model, train_seq, validation_seq, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(
        log_dir=args.save_dir + '/tensorboard-logs',
        batch_size=args.batch_size,
        histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(
        args.save_dir + '/weights-{epoch:02d}.h5',
        monitor='val_acc',
        save_best_only=True,
        save_weights_only=True,
        verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(
        optimizer=optimizers.Adam(lr=args.lr),
        loss=margin_loss,
        metrics={'capsnet': 'accuracy'})
    """
    # Training without data augmentation:
    """
    model.fit_generator(
        generator=train_seq,
        validation_data=validation_seq,
        epochs=args.epochs,
        class_weight={
            0: 1,
            1: 1.8669997421
        },
        callbacks=[log, tb, checkpoint, lr_decay])

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 17
def train(model, train_list, val_list):

    (x_train, y_train) = train_list
    (x_test, y_test) = val_list

    # Compile the loaded model
    opt = Adam(lr=args.lr, beta_1=0.99, beta_2=0.999, decay=1e-6)
    metrics = {'out_seg': [dice_hard, dice_hard_intersection, dice_hard_union]}
    loss = {'out_seg': 'binary_crossentropy', 'out_recon': 'mse'}
    loss_weighting = {'out_seg': 1., 'out_recon': args.recon_wei}

    model.compile(optimizer=opt, loss=loss, loss_weights=loss_weighting, metrics=metrics)

    #######################################
    #               CallBacks             #
    #######################################

    monitor_name = 'val_out_seg_dice_hard_intersection'

    csv_logger = CSVLogger(join(args.logs_dir, 'logs_'+str(args.time)+'.csv'))
    tb = TensorBoard(join(args.logs_dir, 'tensorflow_logs'), batch_size=args.batch_size, histogram_freq=0)
    model_checkpoint = ModelCheckpoint(join(args.weights_dir, 'weights_{epoch:04d}.h5'),
                                       monitor=monitor_name, save_weights_only=True,
                                       verbose=1)
    lr_reducer = ReduceLROnPlateau(monitor=monitor_name, factor=0.05, cooldown=0, patience=5, verbose=1, mode='max')
    early_stopper = EarlyStopping(monitor=monitor_name, min_delta=0, patience=25, verbose=0, mode='max')

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size):
        train_datagen = ImageDataGenerator()
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training via the generator above (the plain ImageDataGenerator applies no augmentation by default).
    model.fit_generator(generator=train_generator(x_train, y_train, args.batch_size),
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[[x_test, y_test], [y_test, x_test]],
                        callbacks=[model_checkpoint, csv_logger, lr_reducer, early_stopper, tb], verbose=1)

    model.save_weights(join(args.models_dir, "model_"+str(args.time)+".h5"))
    print("=" * 50)
    print("Model saved to --> " + str(join(args.models_dir, "model_"+str(args.time)+".h5")))
    print("=" * 50)

    plot_log(args.logs_dir, 'logs_' + str(args.time) + '.csv', save=True, verbose=args.debug)
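The `dice_hard*` metrics come from the surrounding segmentation project (this example resembles SegCaps); `dice_hard_intersection` and `dice_hard_union` presumably log the numerator and denominator separately. A plausible sketch of the hard Dice coefficient itself, assuming binary masks and a 0.5 threshold:

from keras import backend as K

def dice_hard(y_true, y_pred, threshold=0.5, smooth=1e-5):
    # Threshold the predicted mask, then compute 2*|A n B| / (|A| + |B|).
    y_pred_hard = K.cast(K.greater(y_pred, threshold), 'float32')
    intersection = K.sum(y_true * y_pred_hard, axis=[1, 2, 3])
    sums = K.sum(y_true, axis=[1, 2, 3]) + K.sum(y_pred_hard, axis=[1, 2, 3])
    return K.mean((2. * intersection + smooth) / (sums + smooth))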
Example No. 18
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args['save_dir'] + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args['save_dir'] + '/tensorboard-logs',
                               batch_size=args['batch_size'],
                               histogram_freq=int(args['debug']))
    checkpoint = callbacks.ModelCheckpoint(args['save_dir'] +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args['lr'] * (args['lr_decay']**epoch))
    early_stop = callbacks.EarlyStopping(monitor='val_loss',
                                         patience=3,
                                         verbose=1)

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args['lr']),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args['lam_recon']],
                  metrics={'capsnet': 'accuracy'})

    # Training
    model.fit([x_train, y_train], [y_train, x_train],
              batch_size=args['batch_size'],
              epochs=args['epochs'],
              validation_data=[[x_test, y_test], [y_test, x_test]],
              callbacks=[log, tb, checkpoint, lr_decay, early_stop])

    model.save_weights(args['save_dir'] + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args['save_dir'])

    plot_log(args['save_dir'] + '/log.csv', show=True)

    return model
Example No. 19
def train(model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: 0.001 * np.exp(-epoch / 10.))

    # compile the model
    model.compile(optimizer='adam',
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'out_caps': 'accuracy'})

    model.fit([x_train, y_train], [y_train, x_train],
              batch_size=args.batch_size,
              epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]],
              callbacks=[log, tb, checkpoint],
              verbose=1)

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 20
def train(model, data, args):

    (x_train, y_train), (x_test, y_test) = data

    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/best-weight.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (0.9**epoch))

    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    def train_generator(x, y, batch_size):
        train_datagen = ImageDataGenerator(width_shift_range=0.1,
                                           height_shift_range=0.1)

        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    model.fit_generator(
        generator=train_generator(x_train, y_train, args.batch_size),
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, checkpoint, lr_decay])

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 21
def train(model, data, args):
    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size, histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir + '/weights-{epoch:02d}.hdf5', monitor='val_acc',
                                           save_best_only=True, save_weights_only=True, verbose=1)
    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: args.lr * (args.lr_decay ** epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=categorical_crossentropy,
                  metrics=['accuracy'])      

    # Generator with data augmentation as used in [1] ([...] also trained on 2-pixel shifted MNIST)
    def train_generator_with_augmentation(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            if args.crop_x is not None and args.crop_y is not None:
                x_batch = utils.random_crop(x_batch, [args.crop_x, args.crop_y])  
            yield (x_batch, y_batch)

    
    generator = train_generator_with_augmentation(x_train, y_train, args.batch_size, args.shift_fraction)
    model.fit_generator(generator=generator,
                        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
                        epochs=args.epochs,
                        validation_data=[x_test, y_test],
                        callbacks=[log, tb, checkpoint, lr_decay])

    model.save_weights(args.save_dir + '/trained_model.hdf5')
    print('Trained model saved to \'%s/trained_model.hdf5\'' % args.save_dir)

    utils.plot_log(args.save_dir + '/log.csv', show=True)

    return model
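`utils.random_crop` is not shown. A minimal sketch that crops every image in an NHWC batch to [crop_x, crop_y] at a random offset, as the call above suggests:

import numpy as np

def random_crop(batch, crop_size):
    # Crop each image in the batch to crop_size at a random position.
    crop_h, crop_w = crop_size
    n, h, w, c = batch.shape
    out = np.empty((n, crop_h, crop_w, c), dtype=batch.dtype)
    for i in range(n):
        top = np.random.randint(0, h - crop_h + 1)
        left = np.random.randint(0, w - crop_w + 1)
        out[i] = batch[i, top:top + crop_h, left:left + crop_w]
    return out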
Example No. 22
def train(model, train_seq, validation_seq, args):
    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)

    # compile the model
    optimizer, losses, loss_weights, reduce_lr = ResNet101HyperParameters(
        args.lr)
    model.compile(optimizer=optimizer,
                  loss=losses,
                  loss_weights=loss_weights,
                  metrics=['accuracy'])

    # Training without data augmentation:
    model.fit_generator(generator=train_seq,
                        validation_data=validation_seq,
                        epochs=args.epochs,
                        class_weight={
                            0: 1,
                            1: 1.8669997421
                        },
                        callbacks=[log, tb, checkpoint, reduce_lr])

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)
    return model
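`ResNet101HyperParameters` bundles the optimizer, loss configuration, and a learning-rate callback; its definition is not shown. A hypothetical version consistent with how its return values are unpacked and used above:

from keras import optimizers, callbacks

def ResNet101HyperParameters(lr):
    # Returns (optimizer, losses, loss_weights, reduce_lr) as consumed above.
    optimizer = optimizers.Adam(lr=lr)
    losses = 'binary_crossentropy'   # placeholder loss choice
    loss_weights = None
    reduce_lr = callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5,
                                            patience=3, min_lr=1e-6, verbose=1)
    return optimizer, losses, loss_weights, reduce_lr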
Example No. 23
def train(model, data, args):
    X, Y = data

    # Callbacks
    log = callbacks.CSVLogger(filename=args.save_dir + '/log.csv')

    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)

    checkpoint = callbacks.ModelCheckpoint(filepath=args.save_dir + '/weights-improvement-{epoch:02d}.hdf5',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)

    lr_decay = callbacks.LearningRateScheduler(schedule=lambda epoch: 0.001 * np.exp(-epoch / 10.))

    # compile the model
    model.compile(optimizer='adam',
                  loss=[margin_loss],
                  metrics=['accuracy'])

    model.fit(x=X,
              y=Y,
              validation_split=0.2,
              batch_size=args.batch_size,
              epochs=args.epochs,
              callbacks=[log, tb, checkpoint, lr_decay],
              shuffle=True,
              verbose=1)

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)
    return model
Example No. 24
def train(model, eval_model, data, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    (x_train, y_train), (x_test, y_test), classes = data

    print("x_train {}, y_train {}, x_test {}, y_test {}".format(
        x_train.shape, y_train.shape, x_test.shape, y_test.shape))

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_rec_macro',
                                           mode='max',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    if os.path.isfile(args.save_dir + '/trained_model.h5'):
        model.load_weights(args.save_dir + '/trained_model.h5')
    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})
    mc = MetricCallback(validation_data=((x_test, y_test), (y_test, x_test)),
                        labels=classes,
                        batch_size=args.batch_size)
    model.fit([x_train, y_train], [y_train, x_train],
              batch_size=args.batch_size,
              epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]],
              callbacks=[mc, log, tb, checkpoint, lr_decay],
              shuffle=True)

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    plot_log(args.save_dir + '/log.csv', show=True)

    # Convert one-hot labels and predicted probabilities to class indices
    # before computing the sklearn metrics.
    y_pred = np.argmax(
        eval_model.predict(x_test, batch_size=args.batch_size)[0], axis=1)
    y_true = np.argmax(y_test, axis=1)
    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred)
    recall = recall_score(y_true, y_pred, average="macro")
    print("Accuracy: {:.2f}%".format(acc * 100))
    print("Recall score: {:.2f}%".format(recall * 100))
    print("Confusion matrix:\n{}".format(cm))

    return model
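`MetricCallback` is a project-specific callback; judging by the checkpoint's monitor, it computes macro recall on the validation set and exposes it as `val_rec_macro`. A rough sketch under that assumption (note it must run before the checkpoint in the callbacks list, as it does above):

import numpy as np
from keras import callbacks
from sklearn.metrics import recall_score

class MetricCallback(callbacks.Callback):
    # Hypothetical sketch: adds logs['val_rec_macro'] at the end of each epoch.
    def __init__(self, validation_data, labels, batch_size):
        super(MetricCallback, self).__init__()
        (self.x_val, self.y_val), _ = validation_data
        self.labels = labels
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        y_prob = self.model.predict([self.x_val, self.y_val],
                                    batch_size=self.batch_size)[0]
        y_pred = np.argmax(y_prob, axis=1)
        y_true = np.argmax(self.y_val, axis=1)
        if logs is not None:
            logs['val_rec_macro'] = recall_score(y_true, y_pred, average='macro')

Example No. 25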
def train(model, data, args):
    """
   Kapsül Ağının Eğitimi
              : "model" parametresi: CapsNet (Kapsül Ağ) Modeli
              :"data" parametresi: Eğitim ve test verisinden bir grup içerir, örneğin; `((x_train, y_train), (x_test, y_test))`
              :"args" parametresi: Bağımsız değişkenler
              : Fonksiyon çıktısı: Eğitilmiş model
    """

    # unpacking the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=args.debug)
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (0.9**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(width_shift_range=shift_fraction,
                                           height_shift_range=shift_fraction)
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. With shift_fraction=0., this is equivalent to no augmentation.
    model.fit_generator(
        generator=train_generator(x_train, y_train, args.batch_size,
                                  args.shift_fraction),
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, checkpoint, lr_decay])
    """
    # Veri Artırma (Data Augmentation) Yapmadan Modelin Eğitimi:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    """

    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
Example No. 26
tb = callbacks.TensorBoard(log_dir=saveDir + '/tensorboard-logs',
                           batch_size=batch_size)
log = callbacks.CSVLogger(saveDir + '/log.csv')

for layer in full_model.layers[1].layers[:-1]:
    layer.trainable = False

history = full_model.fit_generator(generator=train_generator(batch_size),
                                   steps_per_epoch=int(520 / batch_size),
                                   epochs=epochs,
                                   verbose=1,
                                   validation_data=val_generator(batch_size),
                                   validation_steps=int(208 / batch_size),
                                   callbacks=[es_cb, tb, log])
from utils import plot_log

plot_log(saveDir + '/log.csv', show=True)

test_eval = []


def test_generator(batch_size=batch_size):
    test_datagen = ImageDataGenerator(rescale=1. / 255)
    generator_test = test_datagen.flow_from_directory(
        directory="D:/EURECOM_Kinect_Face_Dataset/RGB/test",
        target_size=(224, 224),
        color_mode="rgb",
        batch_size=1,
        class_mode="categorical",
        shuffle=True,
        seed=42)
    # the original snippet ends abruptly here; returning the iterator so the
    # function is usable (assumed completion)
    return generator_test
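# Hedged sketch (assumed, not in the original snippet): one way to fill the
# `test_eval` list above, assuming `full_model` accepts a single RGB image
# batch and using the iterator returned by test_generator():
gen = test_generator()
for _ in range(gen.samples):      # one full pass over the test directory
    x, y = gen.next()             # batch_size=1 -> one image per step
    test_eval.append(full_model.evaluate(x, y, verbose=0))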
def train(model, data, args):
    """
   Kapsül Ağının Eğitimi    
              :param model:CapsNet (Kapsül Ağ) Modeli 
              :param data:Eğitim ve test verisinden bir grup içerir, örneğin; `((x_train, y_train), (x_test, y_test))`
              :param args:Bağımsız değişkenler
              :return :Eğitilmiş model

    """

    # unpack the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (0.9**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})
    """
    # Veri Büyütme Yapmadan Modelin Eğitimi:
    model.fit([x_train, y_train], [y_train, x_train], batch_size=args.batch_size, epochs=args.epochs,
              validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])

    """

    # Begin: training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(
            width_shift_range=shift_fraction,
            height_shift_range=shift_fraction
        )  # shift up to 2 pixels for MNIST
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Training with data augmentation. If shift_fraction=0., no augmentation is applied.
    model.fit_generator(
        generator=train_generator(x_train, y_train, args.batch_size,
                                  args.shift_fraction),
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, checkpoint, lr_decay])

    # End: training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
def train(model, fold, args):
    """
    Training 
    :param model: the  model
    
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    #    (x_train, y_train), (x_test, y_test) = data

    # callbacks

    save_dirc = args.save_dir + str(fold)
    log = callbacks.CSVLogger(save_dirc + '/log.csv')
    es_cb = callbacks.EarlyStopping(monitor='val_loss',
                                    patience=10,
                                    verbose=1,
                                    mode='auto')
    tb = callbacks.TensorBoard(log_dir=save_dirc + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(save_dirc + '/weights-best.h5',
                                           monitor='val_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=['categorical_crossentropy'],
                  metrics=['accuracy'])

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(batch_size, val_train):
        #        train_datagen = ImageDataGenerator(rescale=1./255,shear_range=0.2,zoom_range=0.2,horizontal_flip=True, validation_split=0.2)
        #
        #        train_generator = train_datagen.flow_from_directory("D:/RGB_D_Dataset/train/depth/",target_size=(224, 224), color_mode="rgb",
        #                                                            batch_size=args.batch_size, class_mode='categorical',subset='training')
        #
        #        validation_generator = train_datagen.flow_from_directory("D:/RGB_D_Dataset/train/depth/",target_size=(224, 224), color_mode="rgb",
        #                                                                 batch_size=args.batch_size,class_mode='categorical',subset='validation')
        #
        batch_size = int(batch_size / 5)
        train_datagen = ImageDataGenerator(rescale=1. / 255,
                                           shear_range=0.2,
                                           zoom_range=0.2,
                                           horizontal_flip=True)
        generator_rgb = train_datagen.flow_from_directory(
            directory="D:/RGB_D_Dataset_new/fold{}/train/RGB/".format(fold),
            target_size=(224, 224),
            color_mode="rgb",
            batch_size=batch_size,
            class_mode="categorical",
            shuffle=True,
            seed=42)
        generator_depth = train_datagen.flow_from_directory(
            directory="D:/RGB_D_Dataset_new/fold{}/train/depth/".format(fold),
            target_size=(224, 224),
            color_mode="rgb",
            batch_size=batch_size,
            class_mode="categorical",
            shuffle=True,
            seed=42)

        generator_rgb_val = train_datagen.flow_from_directory(
            directory="D:/RGB_D_Dataset_new/fold{}/test/RGB/".format(fold),
            target_size=(224, 224),
            color_mode="rgb",
            batch_size=batch_size,
            class_mode="categorical",
            shuffle=True,
            seed=42)
        generator_depth_val = train_datagen.flow_from_directory(
            directory="D:/RGB_D_Dataset_new/fold{}/test/depth/".format(fold),
            target_size=(224, 224),
            color_mode="rgb",
            batch_size=batch_size,
            class_mode="categorical",
            shuffle=True,
            seed=42)
        if val_train == 'train':
            while 1:
                # RGB data augmentation: four extra transformed copies of the batch
                x_batch_rgb, y_batch_rgb = generator_rgb.next()
                flip_img = iaa.Fliplr(1)(images=x_batch_rgb)
                rot_img = iaa.Affine(rotate=(-30, 30))(images=x_batch_rgb)

                shear_aug = iaa.Affine(shear=(-16, 16))(images=x_batch_rgb)
                trans_aug = iaa.Affine(scale={
                    "x": (0.5, 1.5),
                    "y": (0.5, 1.5)
                })(images=x_batch_rgb)
                x_batch_rgb_final = np.concatenate(
                    [x_batch_rgb, flip_img, rot_img, shear_aug, trans_aug],
                    axis=0)
                # five stacked copies of each image -> repeat labels 5 times
                y_batch_rgb_final = np.tile(y_batch_rgb, (5, 1))
                # depth data augmentation (same four transforms)
                x_batch_depth, y_batch_depth = generator_depth.next()
                flip_img = iaa.Fliplr(1)(images=x_batch_depth)
                rot_img = iaa.Affine(rotate=(-30, 30))(images=x_batch_depth)

                shear_aug = iaa.Affine(shear=(-16, 16))(images=x_batch_depth)
                trans_aug = iaa.Affine(scale={
                    "x": (0.5, 1.5),
                    "y": (0.5, 1.5)
                })(images=x_batch_depth)
                x_batch_depth_final = np.concatenate(
                    [x_batch_depth, flip_img, rot_img, shear_aug, trans_aug],
                    axis=0)
                # use the depth labels (the original tiled y_batch_rgb; both
                # iterators share seed=42 so the labels coincide, but
                # y_batch_depth is the intended source)
                y_batch_depth_final = np.tile(y_batch_depth, (5, 1))
                yield [[x_batch_rgb_final, x_batch_depth_final],
                       y_batch_rgb_final]
        elif val_train == 'val':
            while 1:
                x_batch_rgb, y_batch_rgb = generator_rgb_val.next()
                x_batch_depth, y_batch_depth = generator_depth_val.next()
                yield [[x_batch_rgb, x_batch_depth], y_batch_rgb]

    # Training with data augmentation (see train_generator above; this variant has no shift_fraction).
    model.fit_generator(
        generator=train_generator(args.batch_size, 'train'),
        steps_per_epoch=int(424 / int(args.batch_size / 5)),  # 424 samples: IIITD fold 1 (936 for CurtinFaces)
        epochs=args.epochs,
        validation_data=train_generator(args.batch_size, 'val'),
        validation_steps=int(4181 / int(args.batch_size / 5)),  # 4181 samples: IIITD fold 1 (4108 for CurtinFaces)
        callbacks=[log, tb, checkpoint, lr_decay, es_cb])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(save_dirc + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % save_dirc)

    from utils import plot_log
    plot_log(save_dirc + '/log.csv', show=True)

    return model
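# Hedged usage sketch: the fold-aware `train` above can drive a simple
# cross-validation loop; `build_model()` is a hypothetical helper (not defined
# in this listing) that recreates a fresh model for each fold:
# for fold in range(1, 6):
#     model = build_model()
#     train(model, fold, args)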
def train(model, args):
    """
    Training 
    :param model: the  model
    
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    #    (x_train, y_train), (x_test, y_test) = data

    # callbacks

    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    es_cb = callbacks.EarlyStopping(monitor='loss',
                                    patience=10,
                                    verbose=1,
                                    mode='auto')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=['categorical_crossentropy'],
                  metrics=['accuracy'])

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    def train_generator(batch_size, val_train):

        #

        train_datagen = ImageDataGenerator(
            rescale=1. / 255,
            shear_range=0.2,
            zoom_range=0.2,
            horizontal_flip=True)  #, validation_split=0.2)
        generator_rgb = train_datagen.flow_from_directory(
            directory="D:/CurtinFaces_crop/RGB/train/",
            target_size=(224, 224),
            color_mode="rgb",
            batch_size=batch_size,
            class_mode="categorical",
            shuffle=True,
            subset='training',
            seed=42)
        generator_depth = train_datagen.flow_from_directory(
            directory="D:/CurtinFaces_crop/normalized/DEPTH/train/",
            target_size=(224, 224),
            color_mode="rgb",
            batch_size=batch_size,
            class_mode="categorical",
            shuffle=True,
            subset='training',
            seed=42)

        #        generator_rgb_val = train_datagen.flow_from_directory(directory="D:/CurtinFaces_crop/RGB/train/", target_size=(224, 224), color_mode="rgb",
        #                                                      batch_size=batch_size, class_mode="categorical", shuffle=True,subset='validation', seed=42)
        #        generator_depth_val = train_datagen.flow_from_directory(directory="D:/CurtinFaces_crop/normalized/DEPTH/train/", target_size=(224, 224), color_mode="rgb",
        #                                                      batch_size=batch_size, class_mode="categorical", shuffle=True,subset='validation', seed=42)
        if val_train == 'train':
            while 1:
                x_batch_rgb, y_batch_rgb = generator_rgb.next()
                x_batch_depth, y_batch_depth = generator_depth.next()
                yield [[x_batch_rgb, x_batch_depth], y_batch_rgb]
        elif val_train == 'val':
            # NOTE: generator_rgb_val / generator_depth_val exist only in the
            # commented-out block above; enable those definitions before using
            # the 'val' branch, otherwise this raises a NameError.
            while 1:
                x_batch_rgb, y_batch_rgb = generator_rgb_val.next()
                x_batch_depth, y_batch_depth = generator_depth_val.next()
                yield [[x_batch_rgb, x_batch_depth], y_batch_rgb]

    # Training with data augmentation (validation is disabled below).
    model.fit_generator(
        generator=train_generator(args.batch_size, 'train'),
        steps_per_epoch=int(936 / args.batch_size),  # 936 samples: CurtinFaces (424 for IIITD fold 1)
        epochs=args.epochs,
        # validation_data=train_generator(args.batch_size, 'val'),
        # validation_steps=int(156 / args.batch_size),  # 4108 CurtinFaces / 4181 IIITD fold 1
        callbacks=[log, tb, checkpoint, lr_decay, es_cb])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    #    model.save_weights(args.save_dir + '/trained_model.h5')
    #    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
def train(model, args):
    """
    Training a CapsuleNet
    :param model: the CapsuleNet model
    :param data: a tuple containing training and testing data, like `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: The trained model
    """
    # unpacking the data
    #    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=['categorical_crossentropy'],
                  metrics=['accuracy'])

    # Begin: Training with data augmentation ---------------------------------------------------------------------#
    #    def train_generator(batch_size):
    #        train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
    #        generator_rgb = train_datagen.flow_from_directory(directory="D:/RGB_D_Dataset/train/RGB/", target_size=(224, 224), color_mode="rgb",
    #                                                      batch_size=batch_size, class_mode="categorical", shuffle=True, seed=42)
    #        generator_depth = train_datagen.flow_from_directory(directory="D:/RGB_D_Dataset/train/depth/", target_size=(224, 224), color_mode="rgb",
    #                                                      batch_size=batch_size, class_mode="categorical", shuffle=True, seed=42)
    #
    #        while 1:
    #            x_batch_rgb, y_batch_rgb = generator_rgb.next()
    #            x_batch_depth, y_batch_depth = generator_depth.next()
    #            yield [x_batch_depth, y_batch_depth]
    #
    #    def val_generator(batch_size):
    #        train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
    #        generator_rgb = train_datagen.flow_from_directory(directory="D:/RGB_D_Dataset/train/RGB/", target_size=(224, 224), color_mode="rgb",
    #                                                      batch_size=batch_size, class_mode="categorical", shuffle=True, seed=42)
    #        generator_depth = train_datagen.flow_from_directory(directory="D:/RGB_D_Dataset/test/depth/", target_size=(224, 224), color_mode="rgb",
    #                                                      batch_size=batch_size, class_mode="categorical", shuffle=True, seed=42)
    #
    #        while 1:
    #            x_batch_rgb, y_batch_rgb = generator_rgb.next()
    #            x_batch_depth, y_batch_depth = generator_depth.next()
    #            yield [x_batch_depth, y_batch_depth]

    train_datagen = ImageDataGenerator(rescale=1. / 255,
                                       shear_range=0.2,
                                       zoom_range=0.2,
                                       horizontal_flip=True,
                                       validation_split=0.2)

    train_generator = train_datagen.flow_from_directory(
        "D:/RGB_D_Dataset/train/depth/",
        target_size=(224, 224),
        color_mode="rgb",
        batch_size=args.batch_size,
        class_mode='categorical',
        subset='training')

    validation_generator = train_datagen.flow_from_directory(
        "D:/RGB_D_Dataset/train/depth/",
        target_size=(224, 224),
        color_mode="rgb",
        batch_size=args.batch_size,
        class_mode='categorical',
        subset='validation')
    # Training with the augmented directory iterators defined above.
    model.fit_generator(
        generator=train_generator,
        steps_per_epoch=int(train_generator.samples / args.batch_size),
        epochs=args.epochs,
        validation_data=validation_generator,
        validation_steps=int(validation_generator.samples / args.batch_size),
        callbacks=[log, tb, checkpoint, lr_decay])
    # End: Training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
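# (Added note) Both flow_from_directory calls above read the same directory:
# with ImageDataGenerator(validation_split=0.2), subset='training' and
# subset='validation' yield disjoint 80%/20% partitions of the files in each
# class, so validation images are never used as training samples.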
Exemplo n.º 31
0
def train(model, data, args):
    """
    カプセルネットの学習
    :param model: カプセルネットのモデル
    :param data: trainデータもtestデータも含んだタプルデータ, 右のような形式 `((x_train, y_train), (x_test, y_test))`
    :param args: arguments
    :return: 学習ずみモデル
    """
    # unpack the data
    (x_train, y_train), (x_test, y_test) = data

    # callbacks
    log = callbacks.CSVLogger(args.save_dir + '/log.csv')
    tb = callbacks.TensorBoard(log_dir=args.save_dir + '/tensorboard-logs',
                               batch_size=args.batch_size,
                               histogram_freq=int(args.debug))
    checkpoint = callbacks.ModelCheckpoint(args.save_dir +
                                           '/weights-{epoch:02d}.h5',
                                           monitor='val_capsnet_acc',
                                           save_best_only=True,
                                           save_weights_only=True,
                                           verbose=1)
    lr_decay = callbacks.LearningRateScheduler(
        schedule=lambda epoch: args.lr * (args.lr_decay**epoch))

    # compile the model
    model.compile(optimizer=optimizers.Adam(lr=args.lr),
                  loss=[margin_loss, 'mse'],
                  loss_weights=[1., args.lam_recon],
                  metrics={'capsnet': 'accuracy'})
    """
    データ増強無しの学習をする場合:
    model.fit([x_train, y_train], [y_train, x_train], bacth_size=args.batch_size, epochs=args.epochs,
               validation_data=[[x_test, y_test], [y_test, x_test]], callbacks=[log, tb, checkpoint, lr_decay])
    としてください
    """

    # Begin: training with data augmentation ---------------------------------------------------------------------#
    def train_generator(x, y, batch_size, shift_fraction=0.):
        train_datagen = ImageDataGenerator(
            width_shift_range=shift_fraction,
            height_shift_range=shift_fraction)  # shift MNIST images by up to 2 pixels
        generator = train_datagen.flow(x, y, batch_size=batch_size)
        while 1:
            x_batch, y_batch = generator.next()
            yield ([x_batch, y_batch], [y_batch, x_batch])

    # Setting shift_fraction=0 disables the augmentation.
    model.fit_generator(
        generator=train_generator(x_train, y_train, args.batch_size,
                                  args.shift_fraction),
        steps_per_epoch=int(y_train.shape[0] / args.batch_size),
        epochs=args.epochs,
        validation_data=[[x_test, y_test], [y_test, x_test]],
        callbacks=[log, tb, checkpoint, lr_decay])
    # End: training with data augmentation -----------------------------------------------------------------------#

    model.save_weights(args.save_dir + '/trained_model.h5')
    print('Trained model saved to \'%s/trained_model.h5\'' % args.save_dir)

    from utils import plot_log
    plot_log(args.save_dir + '/log.csv', show=True)

    return model
# `log` is referenced in the callbacks list below but not defined in this
# fragment; a CSVLogger matching the plot_log('/log.csv') call at the end is
# assumed here.
log = tf.keras.callbacks.CSVLogger('/log.csv')

checkpoint = tf.keras.callbacks.ModelCheckpoint('/weights-{epoch:02d}.h5',
                                                monitor='val_capsnet_acc',
                                                save_best_only=True,
                                                save_weights_only=True,
                                                verbose=1)

lr_decay = tf.keras.callbacks.LearningRateScheduler(
    schedule=lambda epoch: 0.001 * (0.9**epoch))

# compile the model

train_model.compile(optimizer=tf.keras.optimizers.Adam(lr=0.001),
                    loss=[margin_loss, 'mse'],
                    loss_weights=[1., 0.392],
                    metrics={'capsnet': 'accuracy'})

train_model.summary()

train_model.fit([x_train, y_train], [y_train, x_train],
                batch_size=32,
                epochs=50,
                validation_data=[[x_test, y_test], [y_test, x_test]],
                callbacks=[log, checkpoint, lr_decay])

# save the weights of `train_model`, the model that was compiled and fit
# above (the original line saved `model`, which this fragment never defines)
train_model.save_weights('/trained_model.h5')
print("Trained model saved to '/trained_model.h5'")

from utils import plot_log
plot_log('/log.csv', show=True)
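# Hedged follow-up (assumed): the saved weights can later be restored into a
# freshly built model of the same architecture before inference:
# eval_model.load_weights('/trained_model.h5')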