Code Example #1
File: main.py Project: stjordanis/BarlowTwins-TF
def train_lincls(args, logger, initial_epoch, strategy, num_workers):
    # assert args.snapshot is not None, 'pretrained weight is needed!'
    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.task, args.dataset, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model & Generator
    ##########################
    train_generator = DataLoader(args, args.task, 'train', trainset, args.batch_size, num_workers).dataloader
    val_generator = DataLoader(args, args.task, 'val', valset, args.batch_size, num_workers).dataloader
        
    with strategy.scope():
        backbone = SimSiam(args, logger)
        model = set_lincls(args, backbone.encoder)
        if args.resume and args.snapshot:
            model.load_weights(args.snapshot)
            logger.info('Load weights at {}'.format(args.snapshot))

        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        model.compile(
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            metrics=[tf.keras.metrics.SparseTopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
                     tf.keras.metrics.SparseTopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32)],
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, name='loss'),
            run_eagerly=False)


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps)
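
None of these listings include the projects' OptionalLearningRateSchedule itself; all tf.keras.optimizers.SGD needs is a tf.keras.optimizers.schedules.LearningRateSchedule it can call once per optimizer step. A minimal sketch, assuming a simple step decay (the class name and decay rule here are illustrative, not the actual project code):

import tensorflow as tf

class StepDecaySchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    # Hypothetical stand-in for OptionalLearningRateSchedule: multiplies
    # the base learning rate by `rate` every `decay_steps` optimizer steps.
    def __init__(self, base_lr, decay_steps, rate=0.1):
        super().__init__()
        self.base_lr = base_lr
        self.decay_steps = decay_steps
        self.rate = rate

    def __call__(self, step):
        # `step` is the optimizer's iteration counter (a scalar tensor).
        exponent = tf.math.floor(tf.cast(step, tf.float32) / self.decay_steps)
        return self.base_lr * tf.pow(self.rate, exponent)

    def get_config(self):
        return {'base_lr': self.base_lr,
                'decay_steps': self.decay_steps,
                'rate': self.rate}

optimizer = tf.keras.optimizers.SGD(StepDecaySchedule(.1, 3000), momentum=.9)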
Code Example #2
File: main.py Project: stjordanis/BarlowTwins-TF
def train_pretext(args, logger, initial_epoch, strategy, num_workers):
    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.task, args.dataset, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")


    ##########################
    # Model & Generator
    ##########################
    with strategy.scope():
        model = BarlowTwins(args, logger, num_workers=num_workers)
        if args.summary:
            model.build((None, args.img_size, args.img_size, 3))
            model.summary()
            return 
        
        # Load checkpoints
        if args.snapshot:
            model.build((None, args.img_size, args.img_size, 3))
            model.load_weights(args.snapshot)
            logger.info('Load weights at {}'.format(args.snapshot))

        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        model.compile(
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            loss=tf.keras.losses.cosine_similarity,
            run_eagerly=False)

    train_generator = DataLoader(args, args.task, 'train', trainset, args.batch_size, num_workers).dataloader


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch)
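
Code Example #2 compiles a subclassed model (BarlowTwins) with a bare loss function and no labels anywhere, which works because fit() drives the model's own train_step. A minimal sketch of that pattern, assuming the DataLoader yields two augmented views per batch (the class below is a simplified stand-in, not the project's BarlowTwins):

import tensorflow as tf

class TwoViewModel(tf.keras.Model):
    def __init__(self, encoder):
        super().__init__()
        self.encoder = encoder

    def train_step(self, data):
        view1, view2 = data  # assumed two-view batch from the DataLoader
        with tf.GradientTape() as tape:
            z1 = self.encoder(view1, training=True)
            z2 = self.encoder(view2, training=True)
            # compiled_loss applies whatever was handed to compile(),
            # e.g. tf.keras.losses.cosine_similarity above
            loss = self.compiled_loss(z1, z2)
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        return {'loss': loss}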
Code Example #3
File: main.py Project: ymcidence/MoCo-TF
def train_moco(args, logger, initial_epoch, strategy, num_workers):
    ##########################
    # Dataset
    ##########################
    trainset = set_dataset(args.task, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    ##########################
    # Model & Generator
    ##########################
    with strategy.scope():
        model = MoCo(args, logger)

        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch,
                                                    initial_epoch)
        model.compile(
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            loss=tf.keras.losses.sparse_categorical_crossentropy,
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(1,
                                                         'acc1',
                                                         dtype=tf.float32),
                tf.keras.metrics.TopKCategoricalAccuracy(5,
                                                         'acc5',
                                                         dtype=tf.float32)
            ],
            num_workers=num_workers,
            run_eagerly=True)

    train_generator = DataLoader(args, 'train', trainset, args.batch_size,
                                 num_workers).dataloader

    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
    )
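
Several of these trainers treat the return value of create_callbacks as either a list of Keras callbacks or an integer sentinel (-1 and -2). The helper itself is not in this listing; a sketch of a compatible implementation, with the sentinel meanings and directory layout guessed from how the callers react:

import os
import tensorflow as tf

def create_callbacks(args, logger, initial_epoch):
    # Hypothetical sketch: -1 signals a broken/ambiguous result directory
    # ("Check your model."), -2 signals a silent skip; otherwise return
    # the callbacks that fit() should use.
    path = os.path.join(args.result_path, args.task, str(args.stamp))
    if args.resume and not os.path.isdir(path):
        return -1, initial_epoch
    callbacks = []
    if args.checkpoint:
        callbacks.append(tf.keras.callbacks.ModelCheckpoint(
            os.path.join(path, 'checkpoint', '{epoch:04d}.h5'),
            save_weights_only=True))
    if args.tensorboard:
        callbacks.append(tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(path, 'logs')))
    return callbacks, initial_epoch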
Code Example #4
def main():
    args = get_argument()
    assert args.model_name is not None, 'model_name must be set.'

    sys.path.append(args.baseline_path)
    from common import get_logger
    from common import get_session
    from callback_eager import OptionalLearningRateSchedule
    from callback_eager import create_callbacks

    logger = get_logger("MyLogger")

    args, initial_epoch = set_cfg(args, logger)
    if initial_epoch == -1:
        # training was already finished!
        return

    get_session(args)

    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args)

    ##########################
    # Model & Metric & Generator
    ##########################
    progress_desc_train = 'Train : Loss {:.4f} | Acc {:.4f}'
    progress_desc_val = 'Val : Loss {:.4f} | Acc {:.4f}'

    strategy = tf.distribute.MirroredStrategy()
    # strategy = tf.distribute.experimental.CentralStorageStrategy()
    global_batch_size = args.batch_size * strategy.num_replicas_in_sync

    steps_per_epoch = args.steps or len(trainset) // global_batch_size
    validation_steps = len(valset) // global_batch_size

    # lr scheduler
    lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch,
                                                initial_epoch)

    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            from tensorflow.keras.utils import plot_model
            plot_model(model,
                       to_file=os.path.join(args.src_path, 'model.png'),
                       show_shapes=True)
            model.summary(line_length=130)
            return

        # metrics
        metrics = {
            'loss':
            tf.keras.metrics.Mean('loss', dtype=tf.float32),
            'acc':
            tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32),
            'val_loss':
            tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
            'val_acc':
            tf.keras.metrics.CategoricalAccuracy('val_acc', dtype=tf.float32),
        }

        # optimizer
        if args.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(lr_scheduler,
                                                momentum=.9,
                                                decay=.00005)
        elif args.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(lr_scheduler)
        elif args.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(lr_scheduler)
        else:
            raise ValueError()

        # loss
        if args.loss == 'crossentropy':
            criterion = crossentropy(args)
        else:
            raise ValueError()

        # generator
        if args.loss == 'crossentropy':
            train_generator = dataloader(args, trainset, 'train',
                                         global_batch_size)
            val_generator = dataloader(args,
                                       valset,
                                       'val',
                                       global_batch_size,
                                       shuffle=False)
        else:
            raise ValueError()

        train_generator = strategy.experimental_distribute_dataset(
            train_generator)
        val_generator = strategy.experimental_distribute_dataset(val_generator)

    path = os.path.join(args.result_path, args.dataset, args.model_name,
                        str(args.stamp))
    csvlogger, train_writer, val_writer = create_callbacks(args, metrics, path)
    logger.info("Build Model & Metrics")

    ##########################
    # Log Arguments & Settings
    ##########################
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    logger.info('{} : {}'.format(strategy.__class__.__name__,
                                 strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(global_batch_size))

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # READY Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)

    @tf.function
    def do_step(iterator, mode, loss_name, acc_name=None):
        def step_fn(from_iterator):
            inputs, labels = from_iterator
            if mode == 'train':
                # TODO : recompute the loss
                with tf.GradientTape() as tape:
                    logits = tf.cast(model(inputs, training=True), tf.float32)
                    loss = criterion(labels, logits)
                    loss = tf.reduce_sum(loss) * (1. / global_batch_size)

                grads = tape.gradient(loss, model.trainable_variables)
                optimizer.apply_gradients(
                    list(zip(grads, model.trainable_variables)))
            else:
                logits = tf.cast(model(inputs, training=False), tf.float32)
                loss = criterion(labels, logits)
                loss = tf.reduce_sum(loss) * (1. / global_batch_size)

            metrics[loss_name].update_state(loss)
            metrics[acc_name].update_state(labels, logits)

        strategy.run(step_fn, args=(next(iterator), ))
        # step_fn(next(iterator))

    def desc_update(pbar, desc, loss, acc=None):
        pbar.set_description(desc.format(loss.result(), acc.result()))

    ##########################
    # Train
    ##########################
    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch + 1, args.epochs))
        print('Learning Rate : {}'.format(
            optimizer.learning_rate(optimizer.iterations)))

        # train
        progressbar_train = tqdm.tqdm(tf.range(steps_per_epoch),
                                      desc=progress_desc_train.format(0, 0),
                                      leave=True)
        for step in progressbar_train:
            do_step(train_iterator, 'train', 'loss', 'acc')
            desc_update(progressbar_train, progress_desc_train,
                        metrics['loss'], metrics['acc'])
            progressbar_train.refresh()

        # eval
        progressbar_val = tqdm.tqdm(tf.range(validation_steps),
                                    desc=progress_desc_val.format(0, 0),
                                    leave=True)
        for step in progressbar_val:
            do_step(val_iterator, 'val', 'val_loss', 'val_acc')
            desc_update(progressbar_val, progress_desc_val,
                        metrics['val_loss'], metrics['val_acc'])
            progressbar_val.refresh()

        # logs
        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch

        if args.checkpoint:
            ckpt_path = '{:04d}_{:.4f}_{:.4f}.h5'.format(
                epoch + 1, logs['val_acc'], logs['val_loss'])
            model.save_weights(os.path.join(path, 'checkpoint', ckpt_path))
            print('\nSaved at {}'.format(
                os.path.join(path, 'checkpoint', ckpt_path)))

        if args.history:
            csvlogger = csvlogger.append(logs, ignore_index=True)
            csvlogger.to_csv(os.path.join(path, 'history/epoch.csv'),
                             index=False)

        if args.tensorboard:
            with train_writer.as_default():
                tf.summary.scalar('loss', metrics['loss'].result(), step=epoch)
                tf.summary.scalar('acc', metrics['acc'].result(), step=epoch)

            with val_writer.as_default():
                tf.summary.scalar('val_loss',
                                  metrics['val_loss'].result(),
                                  step=epoch)
                tf.summary.scalar('val_acc',
                                  metrics['val_acc'].result(),
                                  step=epoch)

        for k, v in metrics.items():
            v.reset_states()
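
The step function above rescales the summed per-example loss by 1/global_batch_size, so that when gradients are summed across replicas the result is a true average over the global batch. Assuming criterion is built with reduction=NONE (which that rescaling implies), tf.nn.compute_average_loss expresses the same arithmetic:

import tensorflow as tf

per_example_loss = tf.constant([.3, .7, .5, .9])
global_batch_size = 8  # e.g. 4 examples on each of 2 replicas

manual = tf.reduce_sum(per_example_loss) * (1. / global_batch_size)
helper = tf.nn.compute_average_loss(per_example_loss,
                                    global_batch_size=global_batch_size)
assert abs(float(manual - helper)) < 1e-6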
Code Example #5
File: main.py Project: liupengcnu/SupCL-TF
def main(args):
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))


    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size // strategy.num_replicas_in_sync))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model & Metric & Generator
    ##########################
    # metrics
    metrics = {
        'loss'    :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss':   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
    }

    if args.loss == 'crossentropy':
        metrics.update({
            'acc1'      : tf.keras.metrics.TopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
            'acc5'      : tf.keras.metrics.TopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32),
            'val_acc1'  : tf.keras.metrics.TopKCategoricalAccuracy(1, 'val_acc1', dtype=tf.float32),
            'val_acc5'  : tf.keras.metrics.TopKCategoricalAccuracy(5, 'val_acc5', dtype=tf.float32)})

    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            model.summary()
            return

        # optimizer
        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        if args.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(lr_scheduler, momentum=.9, decay=.0001)
        elif args.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(lr_scheduler)
        elif args.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(lr_scheduler)
        else:
            raise ValueError()

        # loss & generator
        if args.loss == 'supcon':
            criterion = supervised_contrastive(args, args.batch_size // strategy.num_replicas_in_sync)
            train_generator = dataloader_supcon(args, trainset, 'train', args.batch_size)
            val_generator = dataloader_supcon(args, valset, 'train', args.batch_size, shuffle=False)
        elif args.loss == 'crossentropy':
            criterion = crossentropy(args)
            train_generator = dataloader(args, trainset, 'train', args.batch_size)
            val_generator = dataloader(args, valset, 'val', args.batch_size, shuffle=False)
        else:
            raise ValueError()    
        
        train_generator = strategy.experimental_distribute_dataset(train_generator)
        val_generator = strategy.experimental_distribute_dataset(val_generator)

    csvlogger, train_writer, val_writer = create_callbacks(args, metrics)
    logger.info("Build Model & Metrics")

    ##########################
    # READY Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)
        
    # @tf.function
    def do_step(iterator, mode):
        def get_loss(inputs, labels, training=True):
            logits = tf.cast(model(inputs, training=training), tf.float32)
            loss = criterion(labels, logits)
            loss_mean = tf.nn.compute_average_loss(loss, global_batch_size=args.batch_size)
            return logits, loss, loss_mean

        def step_fn(from_iterator):
            if args.loss == 'supcon':
                (img1, img2), labels = from_iterator
                inputs = tf.concat([img1, img2], axis=0)
            else:
                inputs, labels = from_iterator
            
            if mode == 'train':
                with tf.GradientTape() as tape:
                    logits, loss, loss_mean = get_loss(inputs, labels)

                grads = tape.gradient(loss_mean, model.trainable_variables)
                optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
            else:
                logits, loss, loss_mean = get_loss(inputs, labels, training=False)

            if args.loss == 'crossentropy':
                prefix = '' if mode == 'train' else 'val_'
                metrics[prefix + 'acc1'].update_state(labels, logits)
                metrics[prefix + 'acc5'].update_state(labels, logits)

            return loss

        loss_per_replica = strategy.run(step_fn, args=(next(iterator),))
        loss_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_per_replica, axis=0)
        metrics['loss' if mode == 'train' else 'val_loss'].update_state(loss_mean)
        

    ##########################
    # Train
    ##########################
    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch+1, args.epochs))
        print('Learning Rate : {}'.format(optimizer.learning_rate(optimizer.iterations)))

        # train
        print('Train')
        progBar_train = tf.keras.utils.Progbar(steps_per_epoch, stateful_metrics=metrics.keys())
        for step in range(steps_per_epoch):
            do_step(train_iterator, 'train')
            progBar_train.update(step, values=[(k, v.result()) for k, v in metrics.items() if not 'val' in k])

            if args.tensorboard and args.tb_interval > 0:
                if (epoch*steps_per_epoch+step) % args.tb_interval == 0:
                    with train_writer.as_default():
                        for k, v in metrics.items():
                            if not 'val' in k:
                                tf.summary.scalar(k, v.result(), step=epoch*steps_per_epoch+step)

        if args.tensorboard and args.tb_interval == 0:
            with train_writer.as_default():
                for k, v in metrics.items():
                    if not 'val' in k:
                        tf.summary.scalar(k, v.result(), step=epoch)

        # val
        print('\n\nValidation')
        progBar_val = tf.keras.utils.Progbar(validation_steps, stateful_metrics=metrics.keys())
        for step in range(validation_steps):
            do_step(val_iterator, 'val')
            progBar_val.update(step, values=[(k, v.result()) for k, v in metrics.items() if 'val' in k])
    
        # logs
        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch + 1

        if args.checkpoint:
            if args.loss == 'supcon':
                ckpt_path = '{:04d}_{:.4f}.h5'.format(epoch+1, logs['val_loss'])
            else:
                ckpt_path = '{:04d}_{:.4f}_{:.4f}.h5'.format(epoch+1, logs['val_acc1'], logs['val_loss'])

            model.save_weights(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path))

            print('\nSaved at {}'.format(
                os.path.join(
                    args.result_path, 
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path)))

        if args.history:
            csvlogger = csvlogger.append(logs, ignore_index=True)
            csvlogger.to_csv(os.path.join(args.result_path, '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp)), index=False)

        if args.tensorboard:
            with train_writer.as_default():
                tf.summary.scalar('loss', metrics['loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('acc1', metrics['acc1'].result(), step=epoch)
                    tf.summary.scalar('acc5', metrics['acc5'].result(), step=epoch)

            with val_writer.as_default():
                tf.summary.scalar('val_loss', metrics['val_loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    tf.summary.scalar('val_acc1', metrics['val_acc1'].result(), step=epoch)
                    tf.summary.scalar('val_acc5', metrics['val_acc5'].result(), step=epoch)
        
        for k, v in metrics.items():
            v.reset_states()
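
The stateful_metrics argument passed to tf.keras.utils.Progbar above tells the bar to display those values as-is rather than averaging them across updates, which is the right choice for tf.keras.metrics objects that already maintain running means. A minimal standalone illustration:

import tensorflow as tf

loss_metric = tf.keras.metrics.Mean('loss', dtype=tf.float32)
progbar = tf.keras.utils.Progbar(5, stateful_metrics=['loss'])
for step in range(5):
    loss_metric.update_state(.1 * step)  # stand-in for a real step loss
    progbar.update(step + 1, values=[('loss', loss_metric.result())])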
Code Example #6
def main():
    temp_args = get_arguments()
    assert temp_args.snapshot is not None, 'snapshot must be selected!'
    set_seed()

    args = argparse.ArgumentParser().parse_args(args=[])
    tmp = yaml.full_load(
        open(
            f'{temp_args.result_path}/'
            f'{temp_args.dataset}/'
            f'{temp_args.stamp}/'
            'model_desc.yml', 'r'))

    for k, v in tmp.items():
        setattr(args, k, v)

    args.snapshot = temp_args.snapshot
    args.src_path = temp_args.src_path
    args.data_path = temp_args.data_path
    args.result_path = temp_args.result_path
    args.gpus = temp_args.gpus
    args.batch_size = 1

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")

    ##########################
    # Dataset
    ##########################
    _, valset = set_dataset(args.dataset, args.classes, args.data_path)
    validation_steps = len(valset)

    logger.info("TOTAL STEPS OF DATASET FOR EVALUATION")
    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {validation_steps}")

    ##########################
    # Model & Generator
    ##########################
    model = set_model(args.backbone, args.dataset, args.classes)
    model.load_weights(args.snapshot)
    logger.info(f"Load weights at {args.snapshot}")

    model.compile(loss=args.loss,
                  batch_size=args.batch_size,
                  optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
                  metrics=[
                      tf.keras.metrics.TopKCategoricalAccuracy(k=1,
                                                               name='acc1'),
                      tf.keras.metrics.TopKCategoricalAccuracy(k=5,
                                                               name='acc5')
                  ],
                  xe_loss=tf.keras.losses.categorical_crossentropy,
                  cls_loss=tf.keras.losses.KLD,
                  cls_lambda=args.loss_weight,
                  temperature=args.temperature)

    val_generator = DataLoader(loss='crossentropy',
                               mode='val',
                               datalist=valset,
                               dataset=args.dataset,
                               classes=args.classes,
                               batch_size=args.batch_size,
                               shuffle=False).dataloader()

    ##########################
    # Evaluation
    ##########################
    print(
        model.evaluate(val_generator, steps=validation_steps,
                       return_dict=True))
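
Code Example #6 reconstructs its arguments from a model_desc.yml written at training time. The writing side is not part of this listing; a plausible counterpart, with the path layout inferred from the yaml.full_load() call above, would be:

import os
import yaml

def save_model_desc(args):
    # Hypothetical counterpart to the load above: dump the argparse
    # namespace so a later evaluation run can restore it.
    path = os.path.join(args.result_path, args.dataset, str(args.stamp))
    os.makedirs(path, exist_ok=True)
    with open(os.path.join(path, 'model_desc.yml'), 'w') as f:
        yaml.dump(vars(args), f, default_flow_style=False)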
Code Example #7
File: main.py Project: PaperCodeReview/RegNet-TF
def main():
    set_seed()
    args = get_arguments()
    assert args.model_name is not None, 'model_name must be set.'

    logger = get_logger("MyLogger")
    args, initial_epoch = set_cfg(args, logger)
    if initial_epoch == -1:
        # training was already finished!
        return

    get_session(args)
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))


    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))


    ##########################
    # Generator
    ##########################
    trainset, valset = set_dataset(args)
    train_generator = dataloader(args, trainset, 'train')
    val_generator = dataloader(args, valset, 'val', shuffle=False)
    
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size
    
    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = create_model(args, logger)
        if args.summary:
            from tensorflow.keras.utils import plot_model
            plot_model(model, to_file=os.path.join(args.src_path, 'model.png'), show_shapes=True)
            model.summary(line_length=130)
            return

        # optimizer
        scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        optimizer = tf.keras.optimizers.SGD(scheduler, momentum=.9, decay=.00005)

        model.compile(
            optimizer=optimizer,
            loss=tf.keras.losses.categorical_crossentropy,
            metrics=['acc']
        )


    ##########################
    # Callbacks
    ##########################
    callbacks = create_callbacks(
        args, 
        path=os.path.join(args.result_path, args.dataset, args.model_name, str(args.stamp)))
    logger.info("Build callbacks!")


    ##########################
    # Train
    ##########################
    model.fit(
        x=train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        validation_data=val_generator,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps,
        initial_epoch=initial_epoch,
        verbose=1,
    )
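
The --summary branch in Code Examples #4 and #7 renders the architecture with plot_model, which additionally requires the pydot and graphviz packages. Stripped of the surrounding project code, the branch amounts to:

import tensorflow as tf
from tensorflow.keras.utils import plot_model

model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(2),
])
plot_model(model, to_file='model.png', show_shapes=True)  # needs pydot/graphviz
model.summary(line_length=130)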
Code Example #8
def train_pixpro(args, logger, initial_epoch, strategy, num_workers):
    ##########################
    # Dataset
    ##########################
    trainset = set_dataset(args.task, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    ##########################
    # Model & Generator
    ##########################
    train_generator = DataLoader(args, 'train', trainset,
                                 args.batch_size).dataloader
    with strategy.scope():
        model = PixPro(logger,
                       norm='bn' if num_workers == 1 else 'syncbn',
                       channel=256,
                       gamma=args.gamma,
                       num_layers=args.num_layers,
                       snapshot=args.snapshot)

        if args.summary:
            model.summary()
            return

        lr_scheduler = OptionalLearningRateSchedule(
            lr=args.lr,
            lr_mode=args.lr_mode,
            lr_interval=args.lr_interval,
            lr_value=args.lr_value,
            total_epochs=args.epochs,
            steps_per_epoch=steps_per_epoch,
            initial_epoch=initial_epoch)

        model.compile(
            # TODO : apply LARS
            optimizer=tf.keras.optimizers.SGD(lr_scheduler, momentum=.9),
            batch_size=args.batch_size,
            num_workers=num_workers,
            run_eagerly=None)

    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
    )
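
Code Example #8 selects 'syncbn' whenever more than one replica is in sync, so batch statistics are aggregated across GPUs instead of being computed per replica. A sketch of how that flag might map to Keras layers (the real PixPro implementation is not shown in this listing):

import tensorflow as tf

def make_norm(norm):
    # Assumed mapping for the norm flag above: 'syncbn' shares batch
    # statistics across replicas, 'bn' normalizes within each replica.
    if norm == 'syncbn':
        return tf.keras.layers.experimental.SyncBatchNormalization()
    return tf.keras.layers.BatchNormalization()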
Code Example #9
def main(args=None):
    set_seed()
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    strategy = tf.distribute.MirroredStrategy()
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % strategy.num_replicas_in_sync == 0

    logger.info('{} : {}'.format(strategy.__class__.__name__, strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.data_path, args.dataset)
    if args.dataset == 'cifar10':
        num_train, num_val = 50000, 10000
    elif args.dataset == 'svhn':
        num_train, num_val = 73257, 26032
    else:  # imagenet
        num_train, num_val = len(trainset), len(valset)

    # validation_steps was left undefined whenever args.steps was set;
    # compute both unconditionally so the logging below always works.
    steps_per_epoch = args.steps or num_train // args.batch_size
    validation_steps = num_val // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))

    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))


    ##########################
    # Model & Metric & Generator
    ##########################
    metrics = {
        'acc'       :   tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32),
        'val_acc'   :   tf.keras.metrics.CategoricalAccuracy('val_acc', dtype=tf.float32),
        'loss'      :   tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss'  :   tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
        'total_loss':   tf.keras.metrics.Mean('total_loss', dtype=tf.float32),
        'unsup_loss':   tf.keras.metrics.Mean('unsup_loss', dtype=tf.float32)}
    
    with strategy.scope():
        model = 
Code Example #10
File: main.py Project: PaperCodeReview/CSKD-TF
def main():
    args = get_arguments()
    set_seed(args.seed)
    args.classes = CLASS_DICT[args.dataset]
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return

    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info(f"{k} : {v}")


    ##########################
    # Strategy
    ##########################
    if len(args.gpus.split(',')) > 1:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0

    logger.info(f"{strategy.__class__.__name__} : {num_workers}")
    logger.info(f"GLOBAL BATCH SIZE : {args.batch_size}")


    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.dataset, args.classes, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f"    --> {len(trainset)}")
    logger.info(f"    --> {steps_per_epoch}")

    logger.info("=========== VALSET ===========")
    logger.info(f"    --> {len(valset)}")
    logger.info(f"    --> {validation_steps}")


    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = set_model(args.backbone, args.dataset, args.classes)
        if args.snapshot:
            model.load_weights(args.snapshot)
            logger.info(f"Load weights at {args.snapshot}")

        model.compile(
            loss=args.loss,
            optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
            metrics=[
                tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc1'),
                tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc5')],
            xe_loss=tf.keras.losses.categorical_crossentropy,
            cls_loss=tf.keras.losses.KLD,
            cls_lambda=args.loss_weight,
            temperature=args.temperature,
            num_workers=num_workers,
            run_eagerly=True)


    ##########################
    # Generator
    ##########################
    train_generator = DataLoader(
        loss=args.loss,
        mode='train', 
        datalist=trainset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=True).dataloader()

    val_generator = DataLoader(
        loss='crossentropy',
        mode='val', 
        datalist=valset, 
        dataset=args.dataset, 
        classes=args.classes,
        batch_size=args.batch_size, 
        shuffle=False).dataloader()


    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    elif callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps)
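
The compile() call in Code Example #10 wires together a hard-label term (xe_loss), a KL term between temperature-softened predictions (cls_loss), and a mixing weight (cls_lambda). A hedged sketch of how a CSKD-style loss might combine them; the pairing of same-class samples and the exact reduction are assumptions, not the project's code:

import tensorflow as tf

def cskd_loss(y_true, logits, pair_logits, cls_lambda, temperature):
    # Cross-entropy on the labeled sample plus a temperature-scaled KL
    # divergence pulling its softened prediction toward that of a paired
    # same-class sample (class-wise self-knowledge distillation).
    xe = tf.keras.losses.categorical_crossentropy(y_true, logits,
                                                  from_logits=True)
    p = tf.nn.softmax(logits / temperature)
    q = tf.nn.softmax(pair_logits / temperature)
    kld = tf.keras.losses.KLD(tf.stop_gradient(q), p) * temperature ** 2
    return xe + cls_lambda * kld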