Example 1
def create_callbacks(args, metrics):
    if args.snapshot is None:
        if args.checkpoint or args.history or args.tensorboard:
            # Regenerate the run stamp until the result directory is unique.
            flag = True
            while flag:
                try:
                    os.makedirs(
                        os.path.join(args.result_path, args.dataset,
                                     args.stamp))
                    flag = False
                except FileExistsError:
                    args.stamp = create_stamp()

            desc_path = os.path.join(args.result_path, args.dataset,
                                     args.stamp, "model_desc.yml")
            with open(desc_path, "w") as f:
                yaml.dump(vars(args), f, default_flow_style=False)

    if args.checkpoint:
        os.makedirs(os.path.join(
            args.result_path,
            '{}/{}/checkpoint'.format(args.dataset, args.stamp)),
                    exist_ok=True)

    if args.history:
        os.makedirs(os.path.join(
            args.result_path, '{}/{}/history'.format(args.dataset,
                                                     args.stamp)),
                    exist_ok=True)
        csvlogger = pd.DataFrame(columns=['epoch'] + list(metrics.keys()))
        if os.path.isfile(
                os.path.join(
                    args.result_path,
                    '{}/{}/history/epoch.csv'.format(args.dataset,
                                                     args.stamp))):
            csvlogger = pd.read_csv(
                os.path.join(
                    args.result_path,
                    '{}/{}/history/epoch.csv'.format(args.dataset,
                                                     args.stamp)))
        else:
            csvlogger.to_csv(os.path.join(
                args.result_path,
                '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp)),
                             index=False)
    else:
        csvlogger = None

    if args.tensorboard:
        train_writer = tf.summary.create_file_writer(
            os.path.join(args.result_path, args.dataset, args.stamp,
                         'logs/train'))
        val_writer = tf.summary.create_file_writer(
            os.path.join(args.result_path, args.dataset, args.stamp,
                         'logs/val'))
    else:
        train_writer = val_writer = None

    return csvlogger, train_writer, val_writer
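
A minimal usage sketch (hypothetical values; it assumes the module-level imports the function relies on, i.e. os, pandas as pd, yaml, and reads only the args fields shown above):

import argparse

args = argparse.Namespace(
    result_path='./result', dataset='cifar10', stamp='mystamp',
    snapshot=None, checkpoint=True, history=True, tensorboard=False)
metrics = {'loss': None, 'val_loss': None}  # only metrics.keys() is used
csvlogger, train_writer, val_writer = create_callbacks(args, metrics)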
Example 2
def main():
    set_seed()
    args = get_arguments()
    if args.task == 'pretext':
        if args.dataset == 'imagenet':
            args.lr = 0.5 * float(args.batch_size / 256)
        elif args.dataset == 'cifar10':
            args.lr = 0.03 * float(args.batch_size / 256)
    else:
        if args.dataset == 'imagenet' and args.freeze:
            args.lr = 30. * float(args.batch_size / 256)
        else:  # e.g. cifar10, or imagenet without a frozen backbone
            args.lr = 1.8 * float(args.batch_size / 256)

    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        # strategy = tf.distribute.experimental.CentralStorageStrategy()
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size //
                                                     num_workers))

    ##########################
    # Training
    ##########################
    if args.task == 'pretext':
        train_pretext(args, logger, initial_epoch, strategy, num_workers)
    else:
        train_lincls(args, logger, initial_epoch, strategy, num_workers)
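
This example (and Example 3 below) applies the linear scaling rule: the learning rate grows in proportion to the global batch size relative to a reference batch of 256. A standalone sketch of that rule (base rates here are illustrative):

def scaled_lr(base_lr, batch_size, base_batch=256):
    # Linear scaling rule: keep lr proportional to the global batch size.
    return base_lr * batch_size / base_batch

scaled_lr(0.03, 1024)  # -> 0.12, matching the cifar10 pretext branch above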
Example 3
def main():
    set_seed()
    args = get_arguments()
    args.lr = args.lr or 1. * args.batch_size / 256
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("BATCH SIZE PER WORKER : {}".format(args.batch_size //
                                                    num_workers))

    ##########################
    # Training
    ##########################
    if args.task == 'pretext':
        train_pixpro(args, logger, initial_epoch, strategy, num_workers)
    else:
        raise NotImplementedError()
Example 4
def main():
    set_seed()
    args = get_arguments()
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))

    ##########################
    # Training
    ##########################
    if args.task in ['v1', 'v2']:
        train_moco(args, logger, initial_epoch, strategy, num_workers)
    else:
        train_lincls(args, logger, initial_epoch, strategy, num_workers)
Example 5
def create_callbacks(args, logger, initial_epoch):
    if not args.resume:
        if args.checkpoint or args.history or args.tensorboard:
            if os.path.isdir(
                    f'{args.result_path}/{args.dataset}/{args.stamp}'):
                flag = input(
                    f'\n{args.dataset}/{args.stamp} is already saved. '
                    'Do you want new stamp? (y/n) ')
                if flag == 'y':
                    args.stamp = create_stamp()
                    initial_epoch = 0
                    logger.info(f'New stamp {args.stamp} will be created.')
                elif flag == 'n':
                    return -1, initial_epoch
                else:
                    logger.info("You must select 'y' or 'n'.")
                    return -2, initial_epoch

            os.makedirs(f'{args.result_path}/{args.dataset}/{args.stamp}')
            with open(
                    f'{args.result_path}/{args.dataset}/{args.stamp}/model_desc.yml',
                    'w') as f:
                yaml.dump(vars(args), f, default_flow_style=False)
        else:
            logger.info(f'{args.stamp} is not created due to '
                        f'checkpoint - {args.checkpoint} | '
                        f'history - {args.history} | '
                        f'tensorboard - {args.tensorboard}')

    callbacks = []
    if args.checkpoint:
        os.makedirs(
            f'{args.result_path}/{args.dataset}/{args.stamp}/checkpoint',
            exist_ok=True)
        callbacks.append(
            ModelCheckpoint(filepath=os.path.join(
                f'{args.result_path}/{args.dataset}/{args.stamp}/checkpoint',
                '{epoch:04d}_{val_loss:.4f}_{val_acc1:.4f}_{val_acc5:.4f}.h5'),
                            monitor='val_acc1',
                            mode='max',
                            verbose=1,
                            save_weights_only=True))

    if args.history:
        os.makedirs(f'{args.result_path}/{args.dataset}/{args.stamp}/history',
                    exist_ok=True)
        callbacks.append(
            CSVLogger(
                filename=
                f'{args.result_path}/{args.dataset}/{args.stamp}/history/epoch.csv',
                separator=',',
                append=True))

    if args.tensorboard:
        callbacks.append(
            TensorBoard(
                log_dir=f'{args.result_path}/{args.dataset}/{args.stamp}/logs',
                histogram_freq=args.tb_histogram,
                write_graph=True,
                write_images=True,
                update_freq=args.tb_interval,
                profile_batch=2,
            ))

    if args.lr_scheduler:

        def scheduler(epoch):
            if epoch < 100:
                return 0.1
            elif epoch < 150:
                return 0.01
            else:
                return 0.001

        callbacks.append(LearningRateScheduler(schedule=scheduler, verbose=1))

    return callbacks, initial_epoch
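
The inline scheduler above is a fixed step decay (0.1 until epoch 100, 0.01 until 150, then 0.001). The same decay can be expressed as a Keras schedule attached directly to the optimizer; a sketch, assuming a hypothetical steps_per_epoch of 390:

import tensorflow as tf

lr = tf.keras.optimizers.schedules.PiecewiseConstantDecay(
    boundaries=[100 * 390, 150 * 390],  # boundaries in optimizer steps
    values=[0.1, 0.01, 0.001])
optimizer = tf.keras.optimizers.SGD(lr, momentum=0.9)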
Example 6
def create_callbacks(args, logger, initial_epoch):
    if not args.resume:
        if args.checkpoint or args.history:
            if os.path.isdir(f'{args.result_path}/{args.task}/{args.stamp}'):
                flag = input(f'\n{args.task}/{args.stamp} is already saved. '
                             'Do you want new stamp? (y/n) ')
                if flag == 'y':
                    args.stamp = create_stamp()
                    initial_epoch = 0
                    logger.info(f'New stamp {args.stamp} will be created.')
                elif flag == 'n':
                    return -1, initial_epoch
                else:
                    logger.info("You must select 'y' or 'n'.")
                    return -2, initial_epoch

            os.makedirs(f'{args.result_path}/{args.task}/{args.stamp}')
            with open(
                    f'{args.result_path}/{args.task}/{args.stamp}/model_desc.yml',
                    'w') as f:
                yaml.dump(vars(args), f, default_flow_style=False)
        else:
            logger.info(f'{args.stamp} is not created due to '
                        f'checkpoint - {args.checkpoint} | '
                        f'history - {args.history}')

    callbacks = []
    if args.checkpoint:
        # Create the checkpoint directory up front; the ModelCheckpoint
        # callbacks below save .h5 files into it.
        os.makedirs(f'{args.result_path}/{args.task}/{args.stamp}/checkpoint',
                    exist_ok=True)
        if args.task == 'pretext':
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/latest.h5',
                    monitor='loss',
                    mode='min',
                    verbose=1,
                    save_weights_only=True))
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/best.h5',
                    monitor='loss',
                    mode='min',
                    verbose=1,
                    save_weights_only=True,
                    save_best_only=True))
        else:
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/latest.h5',
                    monitor='val_acc1',
                    mode='max',
                    verbose=1,
                    save_weights_only=True))
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/best.h5',
                    monitor='val_acc1',
                    mode='max',
                    verbose=1,
                    save_weights_only=True,
                    save_best_only=True))

    if args.history:
        os.makedirs(f'{args.result_path}/{args.task}/{args.stamp}/history',
                    exist_ok=True)
        if args.task == 'pretext':
            callbacks.append(
                CustomCSVLogger(
                    filename=
                    f'{args.result_path}/{args.task}/{args.stamp}/history/epoch.csv',
                    separator=',',
                    append=True))
        else:
            callbacks.append(
                CSVLogger(
                    filename=
                    f'{args.result_path}/{args.task}/{args.stamp}/history/epoch.csv',
                    separator=',',
                    append=True))

    return callbacks, initial_epoch
Example 7
def main(args):
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))


    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size // strategy.num_replicas_in_sync))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model & Metric & Generator
    ##########################
    # metrics
    metrics = {
        'loss'    :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss':   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
    }

    if args.loss == 'crossentropy':
        metrics.update({
            'acc1'      : tf.keras.metrics.TopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
            'acc5'      : tf.keras.metrics.TopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32),
            'val_acc1'  : tf.keras.metrics.TopKCategoricalAccuracy(1, 'val_acc1', dtype=tf.float32),
            'val_acc5'  : tf.keras.metrics.TopKCategoricalAccuracy(5, 'val_acc5', dtype=tf.float32)})

    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            model.summary()
            return

        # optimizer
        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        if args.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(lr_scheduler, momentum=.9, decay=.0001)
        elif args.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(lr_scheduler)
        elif args.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(lr_scheduler)

        # loss & generator
        if args.loss == 'supcon':
            criterion = supervised_contrastive(args, args.batch_size // strategy.num_replicas_in_sync)
            train_generator = dataloader_supcon(args, trainset, 'train', args.batch_size)
            val_generator = dataloader_supcon(args, valset, 'train', args.batch_size, shuffle=False)
        elif args.loss == 'crossentropy':
            criterion = crossentropy(args)
            train_generator = dataloader(args, trainset, 'train', args.batch_size)
            val_generator = dataloader(args, valset, 'val', args.batch_size, shuffle=False)
        else:
            raise ValueError(f'unsupported loss: {args.loss}')
        
        train_generator = strategy.experimental_distribute_dataset(train_generator)
        val_generator = strategy.experimental_distribute_dataset(val_generator)

    csvlogger, train_writer, val_writer = create_callbacks(args, metrics)
    logger.info("Build Model & Metrics")

    ##########################
    # READY Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)
        
    # @tf.function
    def do_step(iterator, mode):
        def get_loss(inputs, labels, training=True):
            logits = tf.cast(model(inputs, training=training), tf.float32)
            loss = criterion(labels, logits)
            loss_mean = tf.nn.compute_average_loss(loss, global_batch_size=args.batch_size)
            return logits, loss, loss_mean

        def step_fn(from_iterator):
            if args.loss == 'supcon':
                (img1, img2), labels = from_iterator
                inputs = tf.concat([img1, img2], axis=0)
            else:
                inputs, labels = from_iterator
            
            if mode == 'train':
                with tf.GradientTape() as tape:
                    logits, loss, loss_mean = get_loss(inputs, labels)

                grads = tape.gradient(loss_mean, model.trainable_variables)
                optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
            else:
                logits, loss, loss_mean = get_loss(inputs, labels, training=False)

            if args.loss == 'crossentropy':
                # The metric keys are acc1/acc5 (top-1/top-5), so update both.
                prefix = '' if mode == 'train' else 'val_'
                metrics[prefix + 'acc1'].update_state(labels, logits)
                metrics[prefix + 'acc5'].update_state(labels, logits)

            return loss

        loss_per_replica = strategy.run(step_fn, args=(next(iterator),))
        loss_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_per_replica, axis=0)
        metrics['loss' if mode == 'train' else 'val_loss'].update_state(loss_mean)
        

    ##########################
    # Train
    ##########################
    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch+1, args.epochs))
        print('Learning Rate : {}'.format(optimizer.learning_rate(optimizer.iterations)))

        # train
        print('Train')
        progBar_train = tf.keras.utils.Progbar(steps_per_epoch, stateful_metrics=metrics.keys())
        for step in range(steps_per_epoch):
            do_step(train_iterator, 'train')
            progBar_train.update(step + 1, values=[(k, v.result()) for k, v in metrics.items() if 'val' not in k])

            if args.tensorboard and args.tb_interval > 0:
                if (epoch*steps_per_epoch+step) % args.tb_interval == 0:
                    with train_writer.as_default():
                        for k, v in metrics.items():
                            if 'val' not in k:
                                tf.summary.scalar(k, v.result(), step=epoch*steps_per_epoch+step)

        if args.tensorboard and args.tb_interval == 0:
            with train_writer.as_default():
                for k, v in metrics.items():
                    if 'val' not in k:
                        tf.summary.scalar(k, v.result(), step=epoch)

        # val
        print('\n\nValidation')
        progBar_val = tf.keras.utils.Progbar(validation_steps, stateful_metrics=metrics.keys())
        for step in range(validation_steps):
            do_step(val_iterator, 'val')
            progBar_val.update(step + 1, values=[(k, v.result()) for k, v in metrics.items() if 'val' in k])
    
        # logs
        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch + 1

        if args.checkpoint:
            if args.loss == 'supcon':
                ckpt_path = '{:04d}_{:.4f}.h5'.format(epoch+1, logs['val_loss'])
            else:
                ckpt_path = '{:04d}_{:.4f}_{:.4f}.h5'.format(epoch+1, logs['val_acc1'], logs['val_loss'])

            model.save_weights(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path))

            print('\nSaved at {}'.format(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path)))

        if args.history:
            # DataFrame.append was removed in pandas 2.0; use concat instead.
            csvlogger = pd.concat([csvlogger, pd.DataFrame([logs])], ignore_index=True)
            csvlogger.to_csv(os.path.join(args.result_path, '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp)), index=False)

        if args.tensorboard:
            with train_writer.as_default():
                tf.summary.scalar('loss', metrics['loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('acc1', metrics['acc1'].result(), step=epoch)

            with val_writer.as_default():
                tf.summary.scalar('val_loss', metrics['val_loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('val_acc1', metrics['val_acc1'].result(), step=epoch)
        
        for k, v in metrics.items():
            v.reset_states()
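
Example 7's loop follows the standard tf.distribute custom-training pattern: per-example losses are divided by the global batch size inside each replica, and per-replica results are reduced across replicas. A self-contained sketch of that pattern (toy model and data, not the repository's):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH = 8

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    optimizer = tf.keras.optimizers.SGD(0.1)

dataset = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([32, 4]),
     tf.one_hot(tf.zeros(32, dtype=tf.int32), 2))).batch(GLOBAL_BATCH)
dist_dataset = strategy.experimental_distribute_dataset(dataset)

@tf.function
def train_step(x, y):
    def step_fn(x, y):
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            per_example = tf.keras.losses.categorical_crossentropy(
                y, logits, from_logits=True)
            # Divide by the GLOBAL batch size so that summing gradients
            # across replicas yields the correct average.
            loss = tf.nn.compute_average_loss(
                per_example, global_batch_size=GLOBAL_BATCH)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    per_replica = strategy.run(step_fn, args=(x, y))
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

for x, y in dist_dataset:
    train_step(x, y)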
Example 8
def main(args=None):
    set_seed()
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    strategy = tf.distribute.MirroredStrategy()
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.data_path, args.dataset)
    if args.steps is not None:
        steps_per_epoch = args.steps
        # validation_steps must be set in this branch too, or the logging
        # below would raise a NameError.
        validation_steps = len(valset) // args.batch_size
    elif args.dataset == 'cifar10':
        steps_per_epoch = 50000 // args.batch_size
        validation_steps = 10000 // args.batch_size
    elif args.dataset == 'svhn':
        steps_per_epoch = 73257 // args.batch_size
        validation_steps = 26032 // args.batch_size
    elif args.dataset == 'imagenet':
        steps_per_epoch = len(trainset) // args.batch_size
        validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))


    ##########################
    # Model & Metric & Generator
    ##########################
    metrics = {
        'acc'       :   tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32),
        'val_acc'   :   tf.keras.metrics.CategoricalAccuracy('val_acc', dtype=tf.float32),
        'loss'      :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss'  :   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
        'total_loss':   tf.keras.metrics.Mean('total_loss', dtype=tf.float32),
        'unsup_loss':   tf.keras.metrics.Mean('unsup_loss', dtype=tf.float32)}
    
    with strategy.scope():
        model = 
Example 9
def main():
    args = get_arguments()
    set_seed(args.seed)
    args.classes = CLASS_DICT[args.dataset]
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")


    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info(f"{strategy.__class__.__name__} : {num_workers}")
    logger.info(f"GLOBAL BATCH SIZE : {args.batch_size}")


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.dataset, args.classes, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = set_model(args.backbone, args.dataset, args.classes)
        if args.snapshot:
            model.load_weights(args.snapshot)
            logger.info(f"Load weights at {args.snapshot}")

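        # NOTE: xe_loss, cls_loss, cls_lambda, temperature and num_workers are
        # not standard Keras compile() arguments; this assumes set_model
        # returns a custom tf.keras.Model subclass whose compile() accepts
        # these extra kwargs for its distillation-style loss.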
        model.compile(
            loss=args.loss,
            optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc1'),
                tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc5')],
            xe_loss=tf.keras.losses.categorical_crossentropy,
            cls_loss=tf.keras.losses.KLD,
            cls_lambda=args.loss_weight,
            temperature=args.temperature,
            num_workers=num_workers,
            run_eagerly=True)


    ##########################
    # Generator
    ##########################
    train_generator = DataLoader(
        loss=args.loss,
        mode='train', 
        datalist=trainset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=True).dataloader()

    val_generator = DataLoader(
        loss='crossentropy',
        mode='val', 
        datalist=valset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=False).dataloader()


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,)
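
The nonstandard compile() call above implies a custom model class; a minimal sketch of how such a subclass could accept those kwargs (hypothetical, since set_model's implementation is not shown here):

import tensorflow as tf

class DistillModel(tf.keras.Model):
    def compile(self, xe_loss=None, cls_loss=None, cls_lambda=0.,
                temperature=1., num_workers=1, **kwargs):
        # Stash the extra training hyper-parameters, then defer the
        # standard arguments (optimizer, loss, metrics, ...) to Keras.
        super().compile(**kwargs)
        self.xe_loss = xe_loss
        self.cls_loss = cls_loss
        self.cls_lambda = cls_lambda
        self.temperature = temperature
        self.num_workers = num_workers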