Example #1
def main():
    set_seed()
    args = get_arguments()
    if args.task == 'pretext':
        if args.dataset == 'imagenet':
            args.lr = 0.5 * float(args.batch_size / 256)
        elif args.dataset == 'cifar10':
            args.lr = 0.03 * float(args.batch_size / 256)
    else:
        if args.dataset == 'imagenet' and args.freeze:
            args.lr = 30. * float(args.batch_size / 256)
        else:  # args.dataset == 'cifar10'
            args.lr = 1.8 * float(args.batch_size / 256)

    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        # strategy = tf.distribute.experimental.CentralStorageStrategy()
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size //
                                                     num_workers))

    ##########################
    # Training
    ##########################
    if args.task == 'pretext':
        train_pretext(args, logger, initial_epoch, strategy, num_workers)
    else:
        train_lincls(args, logger, initial_epoch, strategy, num_workers)
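
Example #1 (and Example #2 below) set the learning rate with the linear scaling rule: a base rate is multiplied by batch_size / 256, so larger global batches take proportionally larger steps. A minimal standalone sketch of that rule; scale_lr is a hypothetical helper, with the cifar10 pretext base rate from Example #1 used for the check:

def scale_lr(base_lr, batch_size, base_batch=256):
    # linear scaling rule: lr grows in proportion to the global batch size
    return base_lr * float(batch_size) / base_batch

assert scale_lr(0.03, 512) == 0.06  # cifar10 pretext rate at a global batch of 512
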
Example #2
def main():
    set_seed()
    args = get_arguments()
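    # fall back to the linear scaling rule when lr is unset (note: an explicit lr of 0 is also falsy and would be replaced)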
    args.lr = args.lr or 1. * args.batch_size / 256
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("BATCH SIZE PER WORKER : {}".format(args.batch_size //
                                                    num_workers))

    ##########################
    # Training
    ##########################
    if args.task == 'pretext':
        train_pixpro(args, logger, initial_epoch, strategy, num_workers)
    else:
        raise NotImplementedError()
Example #3
def main():
    set_seed()
    args = get_arguments()
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))

    ##########################
    # Training
    ##########################
    if args.task in ['v1', 'v2']:
        train_moco(args, logger, initial_epoch, strategy, num_workers)
    else:
        train_lincls(args, logger, initial_epoch, strategy, num_workers)
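
Examples #1 through #3 all pick a tf.distribute strategy from the comma-separated GPU list, differing only in the multi-GPU choice (MirroredStrategy vs. the experimental CentralStorageStrategy). A sketch that factors out the shared pattern; pick_strategy is a hypothetical helper, not part of the project:

import tensorflow as tf

def pick_strategy(gpus: str, mirrored: bool = True):
    # more than one visible GPU -> data-parallel strategy; otherwise pin to one device
    if len(gpus.split(',')) > 1:
        if mirrored:
            return tf.distribute.MirroredStrategy()
        return tf.distribute.experimental.CentralStorageStrategy()
    return tf.distribute.OneDeviceStrategy(device="/gpu:0")
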
Example #4
def main():
    args = set_default(get_argument())
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
        temp = datetime.now()
        args.stamp = "{:02d}{:02d}{:02d}_{}_{:02d}_{:02d}_{:02d}".format(
            temp.year % 100,  # two-digit year ('// 100' would always yield the century, e.g. 20)
            temp.month,
            temp.day,
            weekday[temp.weekday()],
            temp.hour,
            temp.minute,
            temp.second,
        )

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Generator
    ##########################
    trainset, valset = set_dataset(args, logger)

    train_generator = create_generator(args, trainset, "train", args.batch_size)
    val_generator = create_generator(args, valset, "val", args.batch_size)
    test_generator1 = create_generator(args, trainset, "val", 1)
    test_generator2 = create_generator(args, valset, "val", 1)

    if args.class_weight:
        assert args.classes > 1
        from sklearn.utils.class_weight import compute_class_weight

        train_label = trainset[:, 1:].astype(int).argmax(axis=1)  # np.int was removed in NumPy 1.24
        weights = compute_class_weight(
            class_weight="balanced", classes=np.unique(train_label), y=train_label
        )
        class_weight = dict(enumerate(weights))  # model.fit expects a {class_index: weight} dict

    else:
        class_weight = None

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    logger.info("    --> {}".format(steps_per_epoch))
    # logger.info("    --> {}".format(trainset[:, 2:].sum(axis=0)))
    # logger.info("    --> {}".format(class_weight))

    logger.info("=========== valset ===========")
    validation_steps = len(valset) // args.batch_size
    logger.info("    --> {}".format(validation_steps))
    # logger.info("    --> {}".format(valset[:, 2:].sum(axis=0)))

    ##########################
    # Model
    ##########################
    model = create_model(args, logger)
    if args.summary:
        model.summary()
        print(model.inputs[0])
        print(model.get_layer(name="fc2"))
        return

    model = compile_model(args, model, steps_per_epoch)
    logger.info("Build model!")

    ##########################
    # Callbacks
    ##########################
    callbacks = create_callbacks(args, test_generator1, test_generator2, trainset, valset)
    logger.info("Build callbacks!")

    ##########################
    # Train
    ##########################
    model.fit(
        x=train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        validation_data=val_generator,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        class_weight=class_weight,
        initial_epoch=initial_epoch,
        verbose=args.verbose,
    )
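
Example #4 assembles its run stamp by hand from datetime fields; strftime can produce the same string in one call. A sketch of what an equivalent create_stamp could look like (not the project's actual helper), assuming an English locale for the %a weekday name:

from datetime import datetime

def create_stamp():
    # e.g. "240517_Fri_13_05_59" -- same layout as the hand-built stamp above
    return datetime.now().strftime("%y%m%d_%a_%H_%M_%S")
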
Example #5
def main(args):
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))


    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size // strategy.num_replicas_in_sync))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model & Metric & Generator
    ##########################
    # metrics
    metrics = {
        'loss'    :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss':   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
    }

    if args.loss == 'crossentropy':
        metrics.update({
            'acc1'      : tf.keras.metrics.TopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
            'acc5'      : tf.keras.metrics.TopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32),
            'val_acc1'  : tf.keras.metrics.TopKCategoricalAccuracy(1, 'val_acc1', dtype=tf.float32),
            'val_acc5'  : tf.keras.metrics.TopKCategoricalAccuracy(5, 'val_acc5', dtype=tf.float32)})

    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            model.summary()
            return

        # optimizer
        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        if args.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(lr_scheduler, momentum=.9, decay=.0001)
        elif args.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(lr_scheduler)
        elif args.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(lr_scheduler)

        # loss & generator
        if args.loss == 'supcon':
            criterion = supervised_contrastive(args, args.batch_size // strategy.num_replicas_in_sync)
            train_generator = dataloader_supcon(args, trainset, 'train', args.batch_size)
            val_generator = dataloader_supcon(args, valset, 'train', args.batch_size, shuffle=False)
        elif args.loss == 'crossentropy':
            criterion = crossentropy(args)
            train_generator = dataloader(args, trainset, 'train', args.batch_size)
            val_generator = dataloader(args, valset, 'val', args.batch_size, shuffle=False)
        else:
            raise ValueError('unsupported loss: {}'.format(args.loss))
        
        train_generator = strategy.experimental_distribute_dataset(train_generator)
        val_generator = strategy.experimental_distribute_dataset(val_generator)

    csvlogger, train_writer, val_writer = create_callbacks(args, metrics)
    logger.info("Build Model & Metrics")

    ##########################
    # READY Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)
        
    # @tf.function
    def do_step(iterator, mode):
        def get_loss(inputs, labels, training=True):
            logits = tf.cast(model(inputs, training=training), tf.float32)
            loss = criterion(labels, logits)
            loss_mean = tf.nn.compute_average_loss(loss, global_batch_size=args.batch_size)
            return logits, loss, loss_mean

        def step_fn(from_iterator):
            if args.loss == 'supcon':
                (img1, img2), labels = from_iterator
                inputs = tf.concat([img1, img2], axis=0)
            else:
                inputs, labels = from_iterator
            
            if mode == 'train':
                with tf.GradientTape() as tape:
                    logits, loss, loss_mean = get_loss(inputs, labels)

                grads = tape.gradient(loss_mean, model.trainable_variables)
                optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
            else:
                logits, loss, loss_mean = get_loss(inputs, labels, training=False)

            if args.loss == 'crossentropy':
                # the metrics dict defines acc1/acc5 (and val_ variants), not 'acc'/'val_acc'
                prefix = '' if mode == 'train' else 'val_'
                metrics[prefix + 'acc1'].update_state(labels, logits)
                metrics[prefix + 'acc5'].update_state(labels, logits)

            return loss

        loss_per_replica = strategy.run(step_fn, args=(next(iterator),))
        loss_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_per_replica, axis=0)
        metrics['loss' if mode == 'train' else 'val_loss'].update_state(loss_mean)
        

    ##########################
    # Train
    ##########################
    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch+1, args.epochs))
        print('Learning Rate : {}'.format(optimizer.learning_rate(optimizer.iterations)))

        # train
        print('Train')
        progBar_train = tf.keras.utils.Progbar(steps_per_epoch, stateful_metrics=metrics.keys())
        for step in range(steps_per_epoch):
            do_step(train_iterator, 'train')
            progBar_train.update(step, values=[(k, v.result()) for k, v in metrics.items() if 'val' not in k])

            if args.tensorboard and args.tb_interval > 0:
                if (epoch*steps_per_epoch+step) % args.tb_interval == 0:
                    with train_writer.as_default():
                        for k, v in metrics.items():
                            if 'val' not in k:
                                tf.summary.scalar(k, v.result(), step=epoch*steps_per_epoch+step)

        if args.tensorboard and args.tb_interval == 0:
            with train_writer.as_default():
                for k, v in metrics.items():
                    if 'val' not in k:
                        tf.summary.scalar(k, v.result(), step=epoch)

        # val
        print('\n\nValidation')
        progBar_val = tf.keras.utils.Progbar(validation_steps, stateful_metrics=metrics.keys())
        for step in range(validation_steps):
            do_step(val_iterator, 'val')
            progBar_val.update(step, values=[(k, v.result()) for k, v in metrics.items() if 'val' in k])
    
        # logs
        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch + 1

        if args.checkpoint:
            if args.loss == 'supcon':
                ckpt_path = '{:04d}_{:.4f}.h5'.format(epoch+1, logs['val_loss'])
            else:
                ckpt_path = '{:04d}_{:.4f}_{:.4f}.h5'.format(epoch+1, logs['val_acc1'], logs['val_loss'])

            model.save_weights(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path))

            print('\nSaved at {}'.format(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path)))

        if args.history:
            # DataFrame.append was removed in pandas 2.0; pd.concat is the drop-in replacement
            csvlogger = pd.concat([csvlogger, pd.DataFrame([logs])], ignore_index=True)
            csvlogger.to_csv(os.path.join(args.result_path, '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp)), index=False)

        if args.tensorboard:
            with train_writer.as_default():
                tf.summary.scalar('loss', metrics['loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('acc1', metrics['acc1'].result(), step=epoch)
                    tf.summary.scalar('acc5', metrics['acc5'].result(), step=epoch)

            with val_writer.as_default():
                tf.summary.scalar('val_loss', metrics['val_loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('val_acc1', metrics['val_acc1'].result(), step=epoch)
                    tf.summary.scalar('val_acc5', metrics['val_acc5'].result(), step=epoch)
        
        for k, v in metrics.items():
            v.reset_states()
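
Example #5 is the custom-loop variant of the earlier scripts: per-replica, per-example losses are averaged with tf.nn.compute_average_loss against the global batch size, so the gradient update is invariant to the replica count. A stripped-down sketch of that core step, with a toy model and dataset standing in for the project helpers:

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()
GLOBAL_BATCH = 64

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
    optimizer = tf.keras.optimizers.SGD(0.1)
    # reduction=NONE keeps per-example losses so we can average over the global batch ourselves
    loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction=tf.keras.losses.Reduction.NONE)

@tf.function
def train_step(dist_inputs):
    def step_fn(inputs):
        x, y = inputs
        with tf.GradientTape() as tape:
            logits = model(x, training=True)
            # divide by the global batch size, not the per-replica one
            loss = tf.nn.compute_average_loss(loss_fn(y, logits),
                                              global_batch_size=GLOBAL_BATCH)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    per_replica = strategy.run(step_fn, args=(dist_inputs,))
    # each replica already divided by the global batch, so SUM yields the global mean
    return strategy.reduce(tf.distribute.ReduceOp.SUM, per_replica, axis=None)

# one step over a random batch, just to show the plumbing
ds = tf.data.Dataset.from_tensor_slices(
    (tf.random.normal([GLOBAL_BATCH, 4]), tf.zeros([GLOBAL_BATCH], tf.int32))
).batch(GLOBAL_BATCH)
dist_ds = strategy.experimental_distribute_dataset(ds)
print(float(train_step(next(iter(dist_ds)))))
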
Example #6
def main():
    args = set_default(get_argument())
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        weekday = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
        temp = datetime.now()
        args.stamp = "{:02d}{:02d}{:02d}_{}_{:02d}_{:02d}_{:02d}".format(
            temp.year % 100,  # two-digit year ('// 100' would always yield the century, e.g. 20)
            temp.month,
            temp.day,
            weekday[temp.weekday()],
            temp.hour,
            temp.minute,
            temp.second,
        )

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Generator
    ##########################
    trainset, valset, testset = set_dataset(args, logger)
    train_generator = dataloader(args, trainset, 'train')
    val_generator = dataloader(args, valset, 'val', False)

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    validation_steps = len(valset) // args.batch_size
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model
    ##########################
    model = set_model.Backbone(args, logger)
    if args.summary:
        model.summary()
        return

    logger.info("Build model!")

    ##########################
    # Metric
    ##########################
    metrics = {
        'loss':
        tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'acc':
        tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32),
        'val_loss':
        tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
        'val_acc':
        tf.keras.metrics.CategoricalAccuracy('val_acc', dtype=tf.float32),
    }

    csvlogger, lr_scheduler = create_callbacks(args, steps_per_epoch, metrics)
    if args.optimizer == 'sgd':
        optimizer = tf.keras.optimizers.SGD(lr_scheduler,
                                            momentum=.9,
                                            decay=.0001)
    elif args.optimizer == 'adam':
        optimizer = tf.keras.optimizers.Adam(lr_scheduler)

    ##########################
    # Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)

    progress_desc_train = 'Train : Loss {:.4f} | Acc {:.4f}'
    progress_desc_val = 'Val : Loss {:.4f} | Acc {:.4f}'

    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch + 1, args.epochs))
        print('Learning Rate : {}'.format(
            optimizer.learning_rate(optimizer.iterations)))

        progressbar_train = tqdm.tqdm(range(steps_per_epoch),
                                      desc=progress_desc_train.format(0, 0),
                                      leave=True)
        for step in progressbar_train:
            inputs = next(train_iterator)
            img = inputs[0]['main_input']
            label = inputs[1]['main_output']
            with tf.GradientTape() as tape:
                logits = tf.cast(model(img, training=True), tf.float32)
                loss = tf.keras.losses.categorical_crossentropy(label, logits)
                loss = tf.reduce_mean(loss)

            grads = tape.gradient(loss, model.trainable_variables)
            optimizer.apply_gradients(zip(grads, model.trainable_variables))

            metrics['loss'].update_state(loss)
            metrics['acc'].update_state(label, logits)

            progressbar_train.set_description(
                progress_desc_train.format(metrics['loss'].result(),
                                           metrics['acc'].result()))
            progressbar_train.refresh()

        progressbar_val = tqdm.tqdm(range(validation_steps),
                                    desc=progress_desc_val.format(0, 0),
                                    leave=True)
        for step in progressbar_val:
            val_inputs = next(val_iterator)
            val_img = val_inputs[0]['main_input']
            val_label = val_inputs[1]['main_output']
            val_logits = tf.cast(model(val_img, training=False), tf.float32)

            val_loss = tf.keras.losses.categorical_crossentropy(
                val_label, val_logits)
            val_loss = tf.reduce_mean(val_loss)

            metrics['val_loss'].update_state(val_loss)
            metrics['val_acc'].update_state(val_label, val_logits)

            progressbar_val.set_description(
                progress_desc_val.format(metrics['val_loss'].result(),
                                         metrics['val_acc'].result()))
            progressbar_val.refresh()

        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch + 1

        if args.checkpoint:
            model.save_weights(
                os.path.join(
                    args.result_path,
                    '{}/checkpoint/{:04d}_{:.4f}_{:.4f}.h5'.format(
                        args.stamp, epoch + 1, logs['val_acc'],
                        logs['val_loss'])))

            print('\nSaved at {}'.format(
                os.path.join(
                    args.result_path,
                    '{}/checkpoint/{:04d}_{:.4f}_{:.4f}.h5'.format(
                        args.stamp, epoch + 1, logs['val_acc'],
                        logs['val_loss']))))

        if args.history:
            # DataFrame.append was removed in pandas 2.0; pd.concat is the drop-in replacement
            csvlogger = pd.concat([csvlogger, pd.DataFrame([logs])], ignore_index=True)
            csvlogger.to_csv(os.path.join(
                args.result_path, '{}/history/epoch.csv'.format(args.stamp)),
                             index=False)

        for k, v in metrics.items():
            v.reset_states()
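
Example #6 drives Keras metric objects by hand: update_state accumulates across batches, result reads the running value, and reset_states clears it between epochs (newer Keras versions rename it to reset_state). A minimal standalone sketch of that lifecycle:

import tensorflow as tf

acc = tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32)
acc.update_state([[0., 1.]], [[0.3, 0.7]])  # correct: argmax of prediction matches the one-hot label
acc.update_state([[1., 0.]], [[0.2, 0.8]])  # wrong: argmax is class 1, label is class 0
print(float(acc.result()))                  # 0.5 -- running accuracy over both updates
acc.reset_states()                          # start fresh for the next epoch
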
Example #7
def main(args=None):
    set_seed()
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    strategy = tf.distribute.MirroredStrategy()
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.data_path, args.dataset)
    if args.dataset == 'cifar10':
        steps_per_epoch = 50000 // args.batch_size
        validation_steps = 10000 // args.batch_size
    elif args.dataset == 'svhn':
        steps_per_epoch = 73257 // args.batch_size
        validation_steps = 26032 // args.batch_size
    else:  # imagenet
        steps_per_epoch = len(trainset) // args.batch_size
        validation_steps = len(valset) // args.batch_size

    if args.steps is not None:
        steps_per_epoch = args.steps  # manual override; validation_steps stays dataset-derived

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))


    ##########################
    # Model & Metric & Generator
    ##########################
    metrics = {
        'acc'       :   tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32),
        'val_acc'   :   tf.keras.metrics.CategoricalAccuracy('val_acc', dtype=tf.float32),
        'loss'      :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss'  :   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
        'total_loss':   tf.keras.metrics.Mean('total_loss', dtype=tf.float32),
        'unsup_loss':   tf.keras.metrics.Mean('unsup_loss', dtype=tf.float32)}
    
    with strategy.scope():
        model = ...  # the example is cut off at this point in the source listing
Example #8
def main():
    args = get_arguments()
    set_seed(args.seed)
    args.classes = CLASS_DICT[args.dataset]
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")


    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info(f"{strategy.__class__.__name__} : {num_workers}")
    logger.info(f"GLOBAL BATCH SIZE : {args.batch_size}")


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.dataset, args.classes, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = set_model(args.backbone, args.dataset, args.classes)
        if args.snapshot:
            model.load_weights(args.snapshot)
            logger.info(f"Load weights at {args.snapshot}")

        model.compile(
            loss=args.loss,
            optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc1'),
                tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc5')],
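            # the kwargs below assume a custom Model.compile override (e.g. for distillation);
            # stock Keras compile would reject xe_loss, cls_loss, cls_lambda, temperature, num_workers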
            xe_loss=tf.keras.losses.categorical_crossentropy,
            cls_loss=tf.keras.losses.KLD,
            cls_lambda=args.loss_weight,
            temperature=args.temperature,
            num_workers=num_workers,
            run_eagerly=True)


    ##########################
    # Generator
    ##########################
    train_generator = DataLoader(
        loss=args.loss,
        mode='train', 
        datalist=trainset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=True).dataloader()

    val_generator = DataLoader(
        loss='crossentropy',
        mode='val', 
        datalist=valset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=False).dataloader()


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,)
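
Example #8 tracks top-1 and top-5 accuracy with TopKCategoricalAccuracy. A quick standalone check of how k changes the result: with the true class ranked third, top-1 misses while top-3 hits.

import tensorflow as tf

y_true = [[0., 0., 1.]]
y_pred = [[0.5, 0.3, 0.2]]  # true class has the third-highest score
acc1 = tf.keras.metrics.TopKCategoricalAccuracy(k=1)
acc3 = tf.keras.metrics.TopKCategoricalAccuracy(k=3)
acc1.update_state(y_true, y_pred)
acc3.update_state(y_true, y_pred)
print(float(acc1.result()), float(acc3.result()))  # 0.0 1.0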