Example #1
def train_lincls(args, logger, initial_epoch, strategy, num_workers):
    # assert args.snapshot is not None, 'pretrained weight is needed!'
    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.task, args.dataset, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model & Generator
    ##########################
    train_generator = DataLoader(args, args.task, 'train', trainset, args.batch_size, num_workers).dataloader
    val_generator = DataLoader(args, args.task, 'val', valset, args.batch_size, num_workers).dataloader
        
    with strategy.scope():
        backbone = SimSiam(args, logger)
        model = set_lincls(args, backbone.encoder)
        if args.resume and args.snapshot:
            model.load_weights(args.snapshot)
            logger.info('Loaded weights from {}'.format(args.snapshot))

        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        model.compile(
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
                     tf.keras.metrics.SparseTopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32)],
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, name='loss'),
            run_eagerly=False)


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps)
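
The `set_lincls` helper is not shown in this snippet. As a rough sketch, assuming it implements the standard linear-evaluation protocol (a trainable linear head over a frozen encoder) and that `args` carries the `img_size` and `classes` fields seen in the other examples, it could look like this; the real helper may differ:

import tensorflow as tf

def set_lincls(args, encoder):
    # Hypothetical sketch of the linear-evaluation head: freeze the
    # pretrained encoder and train only a linear classifier on top.
    encoder.trainable = False
    inputs = tf.keras.Input(shape=(args.img_size, args.img_size, 3))
    features = encoder(inputs, training=False)
    # If the encoder returns a spatial feature map, pool it first.
    if len(features.shape) == 4:
        features = tf.keras.layers.GlobalAveragePooling2D()(features)
    logits = tf.keras.layers.Dense(args.classes)(features)  # raw logits, matching from_logits=True above
    return tf.keras.Model(inputs, logits, name='lincls')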
Example #2
def train_pretext(args, logger, initial_epoch, strategy, num_workers):
    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.task, args.dataset, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")


    ##########################
    # Model & Generator
    ##########################
    with strategy.scope():
        model = BarlowTwins(args, logger, num_workers=num_workers)
        if args.summary:
            model.build((None, args.img_size, args.img_size, 3))
            model.summary()
            return 
        
        # Load checkpoints
        if args.snapshot:
            model.build((None, args.img_size, args.img_size, 3))
            model.load_weights(args.snapshot)
            logger.info('Loaded weights from {}'.format(args.snapshot))

        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        model.compile(
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            loss=tf.keras.losses.cosine_similarity,
            run_eagerly=False)

    train_generator = DataLoader(args, args.task, 'train', trainset, args.batch_size, num_workers).dataloader


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,)
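
`OptionalLearningRateSchedule` appears throughout these examples but is never defined in them. A minimal sketch, assuming a 'constant' and a 'cosine' mode and the `args.lr` / `args.lr_mode` / `args.epochs` fields that Example #7 passes explicitly; the real schedule may differ:

import math
import tensorflow as tf

class OptionalLearningRateSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # Hypothetical sketch: a constant mode and a cosine-decay mode,
    # offset by initial_epoch so resumed runs continue the schedule.
    def __init__(self, args, steps_per_epoch, initial_epoch):
        super().__init__()
        self.base_lr = args.lr
        self.lr_mode = args.lr_mode              # e.g. 'constant' or 'cosine'
        self.total_steps = args.epochs * steps_per_epoch
        self.offset = initial_epoch * steps_per_epoch

    def __call__(self, step):
        step = tf.cast(step, tf.float32) + self.offset
        if self.lr_mode == 'cosine':
            progress = step / float(self.total_steps)
            return self.base_lr * .5 * (1. + tf.cos(math.pi * progress))
        return tf.constant(self.base_lr, tf.float32)  # constant fallback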
Example #3
File: main.py  Project: ymcidence/MoCo-TF
def train_moco(args, logger, initial_epoch, strategy, num_workers):
    ##########################
    # Dataset
    ##########################
    trainset = set_dataset(args.task, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    ##########################
    # Model & Generator
    ##########################
    with strategy.scope():
        model = MoCo(args, logger)

        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch,
                                                    initial_epoch)
        model.compile(
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(1,
                                                         'acc1',
                                                         dtype=tf.float32),
                tf.keras.metrics.TopKCategoricalAccuracy(5,
                                                         'acc5',
                                                         dtype=tf.float32)
            ],
            num_workers=num_workers,
            run_eagerly=True)

    train_generator = DataLoader(args, 'train', trainset, args.batch_size,
                                 num_workers).dataloader

    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
    )
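
This driver only compiles and fits; MoCo's defining detail, the momentum (EMA) update of the key encoder, lives inside the `MoCo` class itself. A minimal sketch of that update, assuming `encoder_q` / `encoder_k` sub-models (names hypothetical):

import tensorflow as tf

def momentum_update(encoder_q, encoder_k, m=.999):
    # Hypothetical sketch: the key encoder tracks an exponential moving
    # average of the query encoder and receives no gradients.
    for w_q, w_k in zip(encoder_q.weights, encoder_k.weights):
        w_k.assign(m * w_k + (1. - m) * w_q)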
Example #4
def main():
    args = set_default(get_argument())
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
        temp = datetime.now()
        args.stamp = "{:02d}{:02d}{:02d}_{}_{:02d}_{:02d}_{:02d}".format(
            temp.year % 100,  # two-digit year; // 100 would give the century
            temp.month,
            temp.day,
            weekday[temp.weekday()],
            temp.hour,
            temp.minute,
            temp.second,
        )

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Generator
    ##########################
    trainset, valset = set_dataset(args, logger)

    train_generator = create_generator(args, trainset, "train", args.batch_size)
    val_generator = create_generator(args, valset, "val", args.batch_size)
    test_generator1 = create_generator(args, trainset, "val", 1)
    test_generator2 = create_generator(args, valset, "val", 1)

    if args.class_weight:
        assert args.classes > 1
        from sklearn.utils.class_weight import compute_class_weight

        train_label = trainset[:, 1:].astype(int).argmax(axis=1)  # np.int was removed in NumPy 1.24
        class_weight = compute_class_weight(
            class_weight="balanced", classes=np.unique(train_label), y=train_label
        )

    else:
        class_weight = None

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    validation_steps = len(valset) // args.batch_size
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model
    ##########################
    model = create_model(args, logger)
    if args.summary:
        model.summary()
        print(model.inputs[0])
        print(model.get_layer(name="fc2"))
        return

    model = compile_model(args, model, steps_per_epoch)
    logger.info("Build model!")

    ##########################
    # Callbacks
    ##########################
    callbacks = create_callbacks(args, test_generator1, test_generator2, trainset, valset)
    logger.info("Build callbacks!")

    ##########################
    # Train
    ##########################
    model.fit(
        x=train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        validation_data=val_generator,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        class_weight=class_weight,
        initial_epoch=initial_epoch,
        verbose=args.verbose,
    )
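
Example #5 below calls a `create_stamp()` helper instead of building the stamp inline. Presumably it factors out the same `YYMMDD_Weekday_HH_MM_SS` format constructed above; a sketch:

from datetime import datetime

def create_stamp():
    # Presumed equivalent of the inline construction above:
    # two-digit year, month, day, weekday, hour, minute, second.
    weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
    now = datetime.now()
    return "{:02d}{:02d}{:02d}_{}_{:02d}_{:02d}_{:02d}".format(
        now.year % 100, now.month, now.day,
        weekday[now.weekday()], now.hour, now.minute, now.second)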
Example #5
def main(args):
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))


    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size // strategy.num_replicas_in_sync))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model & Metric & Generator
    ##########################
    # metrics
    metrics = {
        'loss'    :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss':   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
    }

    if args.loss == 'crossentropy':
        metrics.update({
            'acc1'      : tf.keras.metrics.TopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
            'acc5'      : tf.keras.metrics.TopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32),
            'val_acc1'  : tf.keras.metrics.TopKCategoricalAccuracy(1, 'val_acc1', dtype=tf.float32),
            'val_acc5'  : tf.keras.metrics.TopKCategoricalAccuracy(5, 'val_acc5', dtype=tf.float32)})

    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            model.summary()
            return

        # optimizer
        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        if args.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(lr_scheduler, momentum=.9, decay=.0001)
        elif args.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(lr_scheduler)
        elif args.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(lr_scheduler)

        # loss & generator
        if args.loss == 'supcon':
            criterion = supervised_contrastive(args, args.batch_size // strategy.num_replicas_in_sync)
            train_generator = dataloader_supcon(args, trainset, 'train', args.batch_size)
            val_generator = dataloader_supcon(args, valset, 'train', args.batch_size, shuffle=False)
        elif args.loss == 'crossentropy':
            criterion = crossentropy(args)
            train_generator = dataloader(args, trainset, 'train', args.batch_size)
            val_generator = dataloader(args, valset, 'val', args.batch_size, shuffle=False)
        else:
            raise ValueError(f'unknown loss: {args.loss}')
        
        train_generator = strategy.experimental_distribute_dataset(train_generator)
        val_generator = strategy.experimental_distribute_dataset(val_generator)

    csvlogger, train_writer, val_writer = create_callbacks(args, metrics)
    logger.info("Build Model & Metrics")

    ##########################
    # READY Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)
        
    # @tf.function
    def do_step(iterator, mode):
        def get_loss(inputs, labels, training=True):
            logits = tf.cast(model(inputs, training=training), tf.float32)
            loss = criterion(labels, logits)
            loss_mean = tf.nn.compute_average_loss(loss, global_batch_size=args.batch_size)
            return logits, loss, loss_mean

        def step_fn(from_iterator):
            if args.loss == 'supcon':
                (img1, img2), labels = from_iterator
                inputs = tf.concat([img1, img2], axis=0)
            else:
                inputs, labels = from_iterator
            
            if mode == 'train':
                with tf.GradientTape() as tape:
                    logits, loss, loss_mean = get_loss(inputs, labels)

                grads = tape.gradient(loss_mean, model.trainable_variables)
                optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
            else:
                logits, loss, loss_mean = get_loss(inputs, labels, training=False)

            if args.loss == 'crossentropy':
                # metrics are registered as acc1/acc5 (val_acc1/val_acc5) above
                prefix = '' if mode == 'train' else 'val_'
                metrics[prefix + 'acc1'].update_state(labels, logits)
                metrics[prefix + 'acc5'].update_state(labels, logits)

            return loss

        loss_per_replica = strategy.run(step_fn, args=(next(iterator),))
        loss_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_per_replica, axis=0)
        metrics['loss' if mode == 'train' else 'val_loss'].update_state(loss_mean)
        

    ##########################
    # Train
    ##########################
    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch+1, args.epochs))
        print('Learning Rate : {}'.format(optimizer.learning_rate(optimizer.iterations)))

        # train
        print('Train')
        progBar_train = tf.keras.utils.Progbar(steps_per_epoch, stateful_metrics=metrics.keys())
        for step in range(steps_per_epoch):
            do_step(train_iterator, 'train')
            progBar_train.update(step, values=[(k, v.result()) for k, v in metrics.items() if 'val' not in k])

            if args.tensorboard and args.tb_interval > 0:
                if (epoch*steps_per_epoch+step) % args.tb_interval == 0:
                    with train_writer.as_default():
                        for k, v in metrics.items():
                            if 'val' not in k:
                                tf.summary.scalar(k, v.result(), step=epoch*steps_per_epoch+step)

        if args.tensorboard and args.tb_interval == 0:
            with train_writer.as_default():
                for k, v in metrics.items():
                    if 'val' not in k:
                        tf.summary.scalar(k, v.result(), step=epoch)

        # val
        print('\n\nValidation')
        progBar_val = tf.keras.utils.Progbar(validation_steps, stateful_metrics=metrics.keys())
        for step in range(validation_steps):
            do_step(val_iterator, 'val')
            progBar_val.update(step, values=[(k, v.result()) for k, v in metrics.items() if 'val' in k])
    
        # logs
        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch + 1

        if args.checkpoint:
            if args.loss == 'supcon':
                ckpt_path = '{:04d}_{:.4f}.h5'.format(epoch+1, logs['val_loss'])
            else:
                ckpt_path = '{:04d}_{:.4f}_{:.4f}.h5'.format(epoch+1, logs['val_acc1'], logs['val_loss'])

            model.save_weights(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path))

            print('\nSaved at {}'.format(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path)))

        if args.history:
            # DataFrame.append was removed in pandas 2.x; concat a one-row frame instead
            csvlogger = pd.concat([csvlogger, pd.DataFrame([logs])], ignore_index=True)
            csvlogger.to_csv(os.path.join(args.result_path, '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp)), index=False)

        if args.tensorboard:
            with train_writer.as_default():
                tf.summary.scalar('loss', metrics['loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('acc1', metrics['acc1'].result(), step=epoch)

            with val_writer.as_default():
                tf.summary.scalar('val_loss', metrics['val_loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('val_acc1', metrics['val_acc1'].result(), step=epoch)
        
        for k, v in metrics.items():
            v.reset_states()
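
Note that `tf.nn.compute_average_loss` in `get_loss` expects per-example losses, so the `crossentropy` factory presumably builds the Keras loss with reduction disabled. A sketch of what it likely returns, assuming the model emits logits:

import tensorflow as tf

def crossentropy(args):
    # Presumed factory: per-example losses (Reduction.NONE), which is
    # what tf.nn.compute_average_loss in get_loss() requires. Options
    # such as label smoothing would come from args.
    return tf.keras.losses.CategoricalCrossentropy(
        from_logits=True,
        reduction=tf.keras.losses.Reduction.NONE)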
Example #6
def main():
    set_seed()
    args = get_arguments()
    assert args.model_name is not None, 'model_name must be set.'

    logger = get_logger("MyLogger")
    args, initial_epoch = set_cfg(args, logger)
    if initial_epoch == -1:
        # training was already finished!
        return

    get_session(args)
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))


    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))


    ##########################
    # Generator
    ##########################
    trainset, valset = set_dataset(args)
    train_generator = dataloader(args, trainset, 'train')
    val_generator = dataloader(args, valset, 'val', shuffle=False)
    
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size
    
    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            from tensorflow.keras.utils import plot_model
            plot_model(model, to_file=os.path.join(args.src_path, 'model.png'), show_shapes=True)
            model.summary(line_length=130)
            return

        # optimizer
        scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        optimizer = tf.keras.optimizers.SGD(scheduler, momentum=.9, decay=.00005)

        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.categorical_crossentropy,
            metrics=['acc']
        )


    ##########################
    # Callbacks
    ##########################
    callbacks = create_callbacks(
        args, 
        path=os.path.join(args.result_path, args.dataset, args.model_name, str(args.stamp)))
    logger.info("Build callbacks!")


    ##########################
    # Train
    ##########################
    model.fit(
        x=train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        validation_data=val_generator,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        initial_epoch=initial_epoch,
        verbose=1,
    )
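
`create_callbacks` here takes only `args` and a run directory. A hedged sketch of a typical implementation, assuming the `checkpoint` / `history` / `tensorboard` flags seen in Example #5; the real helper may differ:

import os
import tensorflow as tf

def create_callbacks(args, path):
    # Hypothetical sketch: the usual checkpoint / CSV / TensorBoard trio
    # rooted at the run directory.
    callbacks = []
    if args.checkpoint:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            os.path.join(path, 'checkpoint', '{epoch:04d}_{val_acc:.4f}.h5'),
            save_weights_only=True))
    if args.history:
        callbacks.append(tf.keras.callbacks.CSVLogger(
            os.path.join(path, 'history', 'epoch.csv'), append=True))
    if args.tensorboard:
        callbacks.append(tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(path, 'logs')))
    return callbacks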
Example #7
def train_pixpro(args, logger, initial_epoch, strategy, num_workers):
    ##########################
    # Dataset
    ##########################
    trainset = set_dataset(args.task, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    ##########################
    # Model & Generator
    ##########################
    train_generator = DataLoader(args, 'train', trainset,
                                 args.batch_size).dataloader
    with strategy.scope():
        model = PixPro(logger,
                       norm='bn' if num_workers == 1 else 'syncbn',
                       channel=256,
                       gamma=args.gamma,
                       num_layers=args.num_layers,
                       snapshot=args.snapshot)

        if args.summary:
            model.summary()
            return

        lr_scheduler = OptionalLearningRateSchedule(
            lr=args.lr,
            lr_mode=args.lr_mode,
            lr_interval=args.lr_interval,
            lr_value=args.lr_value,
            total_epochs=args.epochs,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch)

        model.compile(
            # TODO : apply LARS
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            batch_size=args.batch_size,
            num_workers=num_workers,
            run_eagerly=None)

    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
    )
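
`PixPro.compile` accepts non-standard `batch_size` and `num_workers` arguments, which implies the model overrides `compile` to stash extra training hyperparameters before deferring to the stock Keras implementation. A minimal sketch of that pattern (attribute names assumed):

import tensorflow as tf

class PixPro(tf.keras.Model):
    # Sketch of the implied compile override: keep the extra training
    # hyperparameters on the model, then defer to the stock compile.
    def compile(self, optimizer, batch_size, num_workers, **kwargs):
        super().compile(optimizer=optimizer, **kwargs)
        self.global_batch_size = batch_size
        self.num_replicas = num_workers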
Example #8
def main():
    args = get_arguments()
    set_seed(args.seed)
    args.classes = CLASS_DICT[args.dataset]
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")


    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info(f"{strategy.__class__.__name__} : {num_workers}")
    logger.info(f"GLOBAL BATCH SIZE : {args.batch_size}")


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.dataset, args.classes, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = set_model(args.backbone, args.dataset, args.classes)
        if args.snapshot:
            model.load_weights(args.snapshot)
            logger.info(f"Load weights at {args.snapshot}")

        model.compile(
            loss=args.loss,
            optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc1'),
                tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc5')],
            xe_loss=tf.keras.losses.categorical_crossentropy,
            cls_loss=tf.keras.losses.KLD,
            cls_lambda=args.loss_weight,
            temperature=args.temperature,
            num_workers=num_workers,
            run_eagerly=True)


    ##########################
    # Generator
    ##########################
    train_generator = DataLoader(
        loss=args.loss,
        mode='train', 
        datalist=trainset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=True).dataloader()

    val_generator = DataLoader(
        loss='crossentropy',
        mode='val', 
        datalist=valset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=False).dataloader()


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,)
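
The compile call above wires a hard-label term (`xe_loss`), a temperature-scaled KL term (`cls_loss`), and a mixing weight (`cls_lambda`), which suggests a distillation-style objective inside the model's custom train step. A hedged sketch of how such a combined loss is typically computed; the function name and weighting scheme are assumptions:

import tensorflow as tf

def distillation_loss(labels, student_logits, teacher_logits,
                      cls_lambda=.5, temperature=4.):
    # Hedged sketch: hard-label cross-entropy plus temperature-scaled
    # KL divergence to the teacher (scaled by T^2, as is conventional).
    xe = tf.keras.losses.categorical_crossentropy(
        labels, student_logits, from_logits=True)
    kl = tf.keras.losses.KLD(
        tf.nn.softmax(teacher_logits / temperature),
        tf.nn.softmax(student_logits / temperature)) * temperature ** 2
    return (1. - cls_lambda) * xe + cls_lambda * kl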