def create_callbacks(args, metrics):
    """Prepare logging artifacts for the custom training loop.

    On a fresh run (no snapshot) this creates the result directory —
    regenerating ``args.stamp`` on a name collision — and dumps the run
    configuration to ``model_desc.yml``.  It then builds, per flag:

    Args:
        args: parsed argument namespace; reads ``snapshot``, ``checkpoint``,
            ``history``, ``tensorboard``, ``result_path``, ``dataset`` and
            ``stamp`` (``stamp`` may be mutated on collision).
        metrics: mapping of metric name -> metric object; the keys become
            the columns of the history CSV.

    Returns:
        Tuple ``(csvlogger, train_writer, val_writer)``; each element is
        ``None`` when its corresponding flag is disabled.
    """
    if args.snapshot is None:
        if args.checkpoint or args.history or args.tensorboard:
            # Retry with a fresh stamp until the directory name is unique.
            flag = True
            while flag:
                try:
                    os.makedirs(
                        os.path.join(args.result_path, args.dataset, args.stamp))
                    flag = False
                except FileExistsError:
                    # BUGFIX: was a bare `except:` — only a name collision
                    # should trigger a re-stamp; any other OS error (e.g.
                    # permissions) must propagate instead of looping forever.
                    args.stamp = create_stamp()

            # BUGFIX: was yaml.dump(..., open(...)) which never closed the
            # file handle.
            desc_path = os.path.join(
                args.result_path, args.dataset, args.stamp, "model_desc.yml")
            with open(desc_path, "w") as desc_file:
                yaml.dump(vars(args), desc_file, default_flow_style=False)

    if args.checkpoint:
        os.makedirs(os.path.join(
            args.result_path,
            '{}/{}/checkpoint'.format(args.dataset, args.stamp)), exist_ok=True)

    if args.history:
        os.makedirs(os.path.join(
            args.result_path,
            '{}/{}/history'.format(args.dataset, args.stamp)), exist_ok=True)
        # Resume an existing epoch log, or start a fresh CSV whose columns
        # mirror the metric names.
        history_path = os.path.join(
            args.result_path,
            '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp))
        csvlogger = pd.DataFrame(columns=['epoch'] + list(metrics.keys()))
        if os.path.isfile(history_path):
            csvlogger = pd.read_csv(history_path)
        else:
            csvlogger.to_csv(history_path, index=False)
    else:
        csvlogger = None

    if args.tensorboard:
        train_writer = tf.summary.create_file_writer(
            os.path.join(args.result_path, args.dataset, args.stamp, 'logs/train'))
        val_writer = tf.summary.create_file_writer(
            os.path.join(args.result_path, args.dataset, args.stamp, 'logs/val'))
    else:
        train_writer = val_writer = None

    return csvlogger, train_writer, val_writer
def main():
    """Entry point: set the task-dependent learning rate, then launch
    pretext or linear-classification training under the right strategy."""
    set_seed()
    args = get_arguments()

    # Linear LR scaling rule: base_lr * batch_size / 256.
    scale = float(args.batch_size / 256)
    if args.task == 'pretext':
        if args.dataset == 'imagenet':
            args.lr = 0.5 * scale
        elif args.dataset == 'cifar10':
            args.lr = 0.03 * scale
    elif args.dataset == 'imagenet' and args.freeze:
        args.lr = 30. * scale
    else:
        # remaining lincls cases (e.g. cifar10)
        args.lr = 1.8 * scale

    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # This configuration already ran to completion.
        return
    if initial_epoch == 0:
        # Fresh run (or starting from a snapshot): mint a new stamp.
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for key, value in vars(args).items():
        logger.info("{} : {}".format(key, value))

    ##########################
    # Strategy
    ##########################
    multi_gpu = len(args.gpus.split(',')) > 1
    strategy = (tf.distribute.MirroredStrategy() if multi_gpu
                else tf.distribute.OneDeviceStrategy(device="/gpu:0"))
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0
    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("BATCH SIZE PER REPLICA : {}".format(args.batch_size // num_workers))

    ##########################
    # Training
    ##########################
    trainer = train_pretext if args.task == 'pretext' else train_lincls
    trainer(args, logger, initial_epoch, strategy, num_workers)
def main():
    """Entry point for PixPro pre-training."""
    set_seed()
    args = get_arguments()

    # Default LR follows the linear scaling rule when none was supplied.
    if not args.lr:
        args.lr = 1. * args.batch_size / 256

    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        return  # run already finished
    if initial_epoch == 0:
        args.stamp = create_stamp()  # brand-new run: fresh stamp

    get_session(args)
    logger = get_logger("MyLogger")
    for key, value in vars(args).items():
        logger.info("{} : {}".format(key, value))

    ##########################
    # Strategy
    ##########################
    gpu_list = args.gpus.split(',')
    if len(gpu_list) > 1:
        strategy = tf.distribute.MirroredStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0
    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("BATCH SIZE PER WORKER : {}".format(args.batch_size // num_workers))

    ##########################
    # Training
    ##########################
    if args.task != 'pretext':
        # Only the pretext task is implemented in this script.
        raise NotImplementedError()
    train_pixpro(args, logger, initial_epoch, strategy, num_workers)
def main():
    """Entry point for MoCo (v1/v2) pre-training or linear evaluation."""
    set_seed()
    args = get_arguments()

    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        return  # this run already finished
    if initial_epoch == 0:
        args.stamp = create_stamp()  # new run: fresh stamp

    get_session(args)
    logger = get_logger("MyLogger")
    for name, value in vars(args).items():
        logger.info("{} : {}".format(name, value))

    ##########################
    # Strategy
    ##########################
    use_multi = len(args.gpus.split(',')) > 1
    if use_multi:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    else:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % num_workers == 0
    logger.info('{} : {}'.format(strategy.__class__.__name__, num_workers))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))

    ##########################
    # Training
    ##########################
    if args.task in ('v1', 'v2'):
        train_moco(args, logger, initial_epoch, strategy, num_workers)
    else:
        train_lincls(args, logger, initial_epoch, strategy, num_workers)
def create_callbacks(args, logger, initial_epoch):
    """Build the Keras callback list for training.

    When not resuming, handles a result-directory collision interactively
    (offer a new stamp or abort), creates the run directory, and writes the
    run configuration to ``model_desc.yml``.

    Args:
        args: parsed argument namespace (``stamp`` may be replaced).
        logger: logger for status messages.
        initial_epoch: epoch to resume from; reset to 0 when a new stamp
            is created.

    Returns:
        ``(callbacks, initial_epoch)``; ``callbacks`` is ``-1`` when the
        user declines a new stamp and ``-2`` on an invalid answer.
    """
    if not args.resume:
        if args.checkpoint or args.history or args.tensorboard:
            if os.path.isdir(f'{args.result_path}/{args.dataset}/{args.stamp}'):
                flag = input(
                    f'\n{args.dataset}/{args.stamp} is already saved. '
                    'Do you want new stamp? (y/n) ')
                if flag == 'y':
                    args.stamp = create_stamp()
                    initial_epoch = 0
                    logger.info(f'New stamp {args.stamp} will be created.')
                elif flag == 'n':
                    return -1, initial_epoch
                else:
                    # plain literal: there was nothing to interpolate in the
                    # original f-string
                    logger.info('You must select \'y\' or \'n\'.')
                    return -2, initial_epoch

            os.makedirs(f'{args.result_path}/{args.dataset}/{args.stamp}')
            # BUGFIX: was yaml.dump(..., open(...)) — the handle was never
            # closed.
            with open(f'{args.result_path}/{args.dataset}/{args.stamp}/model_desc.yml',
                      'w') as desc_file:
                yaml.dump(vars(args), desc_file, default_flow_style=False)
        else:
            logger.info(f'{args.stamp} is not created due to '
                        f'checkpoint - {args.checkpoint} | '
                        f'history - {args.history} | '
                        f'tensorboard - {args.tensorboard}')

    callbacks = []
    if args.checkpoint:
        os.makedirs(
            f'{args.result_path}/{args.dataset}/{args.stamp}/checkpoint',
            exist_ok=True)
        callbacks.append(
            ModelCheckpoint(filepath=os.path.join(
                f'{args.result_path}/{args.dataset}/{args.stamp}/checkpoint',
                '{epoch:04d}_{val_loss:.4f}_{val_acc1:.4f}_{val_acc5:.4f}.h5'),
                monitor='val_acc1',
                mode='max',
                verbose=1,
                save_weights_only=True))

    if args.history:
        os.makedirs(f'{args.result_path}/{args.dataset}/{args.stamp}/history',
                    exist_ok=True)
        callbacks.append(
            CSVLogger(
                filename=f'{args.result_path}/{args.dataset}/{args.stamp}/history/epoch.csv',
                separator=',',
                append=True))

    if args.tensorboard:
        callbacks.append(
            TensorBoard(
                log_dir=f'{args.result_path}/{args.dataset}/{args.stamp}/logs',
                histogram_freq=args.tb_histogram,
                write_graph=True,
                write_images=True,
                update_freq=args.tb_interval,
                # profile only the second batch
                profile_batch=2,
            ))

    if args.lr_scheduler:
        # Step decay: 0.1 until epoch 100, 0.01 until 150, then 0.001.
        def scheduler(epoch):
            if epoch < 100:
                return 0.1
            elif epoch < 150:
                return 0.01
            else:
                return 0.001

        callbacks.append(LearningRateScheduler(schedule=scheduler, verbose=1))

    return callbacks, initial_epoch
def create_callbacks(args, logger, initial_epoch):
    """Build the Keras callback list for pretext or downstream training.

    When not resuming, handles a result-directory collision interactively
    (offer a new stamp or abort), creates the run directory, and writes the
    run configuration to ``model_desc.yml``.  Checkpointing monitors the
    training loss for the pretext task and ``val_acc1`` otherwise.

    Args:
        args: parsed argument namespace (``stamp`` may be replaced).
        logger: logger for status messages.
        initial_epoch: epoch to resume from; reset to 0 when a new stamp
            is created.

    Returns:
        ``(callbacks, initial_epoch)``; ``callbacks`` is ``-1`` when the
        user declines a new stamp and ``-2`` on an invalid answer.
    """
    if not args.resume:
        if args.checkpoint or args.history:
            if os.path.isdir(f'{args.result_path}/{args.task}/{args.stamp}'):
                flag = input(f'\n{args.task}/{args.stamp} is already saved. '
                             'Do you want new stamp? (y/n) ')
                if flag == 'y':
                    args.stamp = create_stamp()
                    initial_epoch = 0
                    logger.info(f'New stamp {args.stamp} will be created.')
                elif flag == 'n':
                    return -1, initial_epoch
                else:
                    # plain literal: nothing to interpolate
                    logger.info('You must select \'y\' or \'n\'.')
                    return -2, initial_epoch

            os.makedirs(f'{args.result_path}/{args.task}/{args.stamp}')
            # BUGFIX: was yaml.dump(..., open(...)) — the handle was never
            # closed.
            with open(f'{args.result_path}/{args.task}/{args.stamp}/model_desc.yml',
                      'w') as desc_file:
                yaml.dump(vars(args), desc_file, default_flow_style=False)
        else:
            logger.info(f'{args.stamp} is not created due to '
                        f'checkpoint - {args.checkpoint} | '
                        f'history - {args.history} | ')

    callbacks = []
    if args.checkpoint:
        # BUGFIX: the checkpoint directory was never created (the history
        # branch below creates its own, and the sibling create_callbacks in
        # this project does so too); saving .h5 weights into a missing
        # directory fails at the first checkpoint.
        os.makedirs(f'{args.result_path}/{args.task}/{args.stamp}/checkpoint',
                    exist_ok=True)
        if args.task == 'pretext':
            # Pretext training monitors the training loss.
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/latest.h5',
                    monitor='loss',
                    mode='min',
                    verbose=1,
                    save_weights_only=True))
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/best.h5',
                    monitor='loss',
                    mode='min',
                    verbose=1,
                    save_weights_only=True,
                    save_best_only=True))
        else:
            # Downstream training monitors top-1 validation accuracy.
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/latest.h5',
                    monitor='val_acc1',
                    mode='max',
                    verbose=1,
                    save_weights_only=True))
            callbacks.append(
                ModelCheckpoint(
                    filepath=
                    f'{args.result_path}/{args.task}/{args.stamp}/checkpoint/best.h5',
                    monitor='val_acc1',
                    mode='max',
                    verbose=1,
                    save_weights_only=True,
                    save_best_only=True))

    if args.history:
        os.makedirs(f'{args.result_path}/{args.task}/{args.stamp}/history',
                    exist_ok=True)
        if args.task == 'pretext':
            callbacks.append(
                CustomCSVLogger(
                    filename=
                    f'{args.result_path}/{args.task}/{args.stamp}/history/epoch.csv',
                    separator=',',
                    append=True))
        else:
            callbacks.append(
                CSVLogger(
                    filename=
                    f'{args.result_path}/{args.task}/{args.stamp}/history/epoch.csv',
                    separator=',',
                    append=True))

    return callbacks, initial_epoch
def main(args):
    """Custom (non-Keras-fit) distributed training loop.

    Supports two losses: 'supcon' (supervised contrastive, two augmented
    views per sample concatenated along the batch axis) and 'crossentropy'.
    Metrics, checkpointing, CSV history and TensorBoard writers are all
    driven manually per step/epoch via `do_step`.

    NOTE(review): the original file stores this function on collapsed
    physical lines; line breaks and indentation below were reconstructed,
    and only comments/docstrings were added — every code token is unchanged.
    """
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return
    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    # strategy = tf.distribute.MirroredStrategy()
    strategy = tf.distribute.experimental.CentralStorageStrategy()
    # The global batch must divide evenly across replicas.
    assert args.batch_size % strategy.num_replicas_in_sync == 0
    logger.info('{} : {}'.format(strategy.__class__.__name__,
                                 strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))
    logger.info("BATCH SIZE PER REPLICA : {}".format(
        args.batch_size // strategy.num_replicas_in_sync))

    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args)
    # args.steps, when truthy, overrides the computed steps per epoch.
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size
    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))
    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model & Metric & Generator
    ##########################
    # metrics
    metrics = {
        'loss'    : tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss': tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
    }
    if args.loss == 'crossentropy':
        # NOTE(review): the accuracy metrics are registered under the keys
        # 'acc1'/'acc5'/'val_acc1'/'val_acc5', but the loop below reads
        # metrics['acc'] / metrics['val_acc'] and logs['val_acc'], which
        # will raise KeyError — confirm the intended key names.
        metrics.update({
            'acc1'    : tf.keras.metrics.TopKCategoricalAccuracy(1, 'acc1', dtype=tf.float32),
            'acc5'    : tf.keras.metrics.TopKCategoricalAccuracy(5, 'acc5', dtype=tf.float32),
            'val_acc1': tf.keras.metrics.TopKCategoricalAccuracy(1, 'val_acc1', dtype=tf.float32),
            'val_acc5': tf.keras.metrics.TopKCategoricalAccuracy(5, 'val_acc5', dtype=tf.float32)})

    with strategy.scope():
        # Model, optimizer and distributed datasets are created under the
        # strategy scope so their variables follow the strategy's placement.
        model = create_model(args, logger)
        if args.summary:
            model.summary()
            return

        # optimizer
        lr_scheduler = OptionalLearningRateSchedule(args, steps_per_epoch, initial_epoch)
        if args.optimizer == 'sgd':
            optimizer = tf.keras.optimizers.SGD(lr_scheduler, momentum=.9, decay=.0001)
        elif args.optimizer == 'rmsprop':
            optimizer = tf.keras.optimizers.RMSprop(lr_scheduler)
        elif args.optimizer == 'adam':
            optimizer = tf.keras.optimizers.Adam(lr_scheduler)
        # NOTE(review): an unrecognized optimizer name leaves `optimizer`
        # unbound and fails later with NameError.

        # loss & generator
        if args.loss == 'supcon':
            criterion = supervised_contrastive(
                args, args.batch_size // strategy.num_replicas_in_sync)
            train_generator = dataloader_supcon(args, trainset, 'train', args.batch_size)
            # NOTE(review): the val loader is built in 'train' mode (paired
            # augmentation pipeline) — presumably intentional for supcon
            # evaluation; confirm.
            val_generator = dataloader_supcon(args, valset, 'train',
                                              args.batch_size, shuffle=False)
        elif args.loss == 'crossentropy':
            criterion = crossentropy(args)
            train_generator = dataloader(args, trainset, 'train', args.batch_size)
            val_generator = dataloader(args, valset, 'val', args.batch_size, shuffle=False)
        else:
            raise ValueError()

        train_generator = strategy.experimental_distribute_dataset(train_generator)
        val_generator = strategy.experimental_distribute_dataset(val_generator)

    csvlogger, train_writer, val_writer = create_callbacks(args, metrics)
    logger.info("Build Model & Metrics")

    ##########################
    # READY Train
    ##########################
    train_iterator = iter(train_generator)
    val_iterator = iter(val_generator)

    # @tf.function
    def do_step(iterator, mode):
        """Run one distributed step; `mode == 'train'` updates weights,
        anything else evaluates only."""
        def get_loss(inputs, labels, training=True):
            logits = tf.cast(model(inputs, training=training), tf.float32)
            loss = criterion(labels, logits)
            # Scale by the GLOBAL batch so per-replica gradients sum to the
            # true average gradient.
            loss_mean = tf.nn.compute_average_loss(
                loss, global_batch_size=args.batch_size)
            return logits, loss, loss_mean

        def step_fn(from_iterator):
            if args.loss == 'supcon':
                # Two augmented views are concatenated along the batch axis.
                (img1, img2), labels = from_iterator
                inputs = tf.concat([img1, img2], axis=0)
            else:
                inputs, labels = from_iterator

            if mode == 'train':
                with tf.GradientTape() as tape:
                    logits, loss, loss_mean = get_loss(inputs, labels)
                grads = tape.gradient(loss_mean, model.trainable_variables)
                optimizer.apply_gradients(list(zip(grads, model.trainable_variables)))
            else:
                logits, loss, loss_mean = get_loss(inputs, labels, training=False)

            if args.loss == 'crossentropy':
                # NOTE(review): 'acc' / 'val_acc' are not keys of `metrics`
                # (see the metric definitions above) — this raises KeyError.
                metrics['acc' if mode == 'train' else 'val_acc'].update_state(labels, logits)
            return loss

        loss_per_replica = strategy.run(step_fn, args=(next(iterator),))
        loss_mean = strategy.reduce(tf.distribute.ReduceOp.MEAN, loss_per_replica, axis=0)
        metrics['loss' if mode == 'train' else 'val_loss'].update_state(loss_mean)

    ##########################
    # Train
    ##########################
    for epoch in range(initial_epoch, args.epochs):
        print('\nEpoch {}/{}'.format(epoch+1, args.epochs))
        print('Learning Rate : {}'.format(optimizer.learning_rate(optimizer.iterations)))

        # train
        print('Train')
        progBar_train = tf.keras.utils.Progbar(steps_per_epoch,
                                               stateful_metrics=metrics.keys())
        for step in range(steps_per_epoch):
            do_step(train_iterator, 'train')
            progBar_train.update(step, values=[(k, v.result())
                                               for k, v in metrics.items()
                                               if not 'val' in k])
            if args.tensorboard and args.tb_interval > 0:
                # per-step train scalars, throttled by tb_interval
                if (epoch*steps_per_epoch+step) % args.tb_interval == 0:
                    with train_writer.as_default():
                        for k, v in metrics.items():
                            if not 'val' in k:
                                tf.summary.scalar(k, v.result(),
                                                  step=epoch*steps_per_epoch+step)

        if args.tensorboard and args.tb_interval == 0:
            # tb_interval == 0 means: log train scalars once per epoch
            with train_writer.as_default():
                for k, v in metrics.items():
                    if not 'val' in k:
                        tf.summary.scalar(k, v.result(), step=epoch)

        # val
        print('\n\nValidation')
        progBar_val = tf.keras.utils.Progbar(validation_steps,
                                             stateful_metrics=metrics.keys())
        for step in range(validation_steps):
            do_step(val_iterator, 'val')
            progBar_val.update(step, values=[(k, v.result())
                                             for k, v in metrics.items()
                                             if 'val' in k])

        # logs
        logs = {k: v.result().numpy() for k, v in metrics.items()}
        logs['epoch'] = epoch + 1

        if args.checkpoint:
            if args.loss == 'supcon':
                ckpt_path = '{:04d}_{:.4f}.h5'.format(epoch+1, logs['val_loss'])
            else:
                # NOTE(review): logs has 'val_acc1'/'val_acc5', not
                # 'val_acc' — KeyError for the crossentropy branch.
                ckpt_path = '{:04d}_{:.4f}_{:.4f}.h5'.format(
                    epoch+1, logs['val_acc'], logs['val_loss'])
            model.save_weights(
                os.path.join(
                    args.result_path,
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path))
            print('\nSaved at {}'.format(
                os.path.join(
                    args.result_path,
                    '{}/{}/checkpoint'.format(args.dataset, args.stamp),
                    ckpt_path)))

        if args.history:
            # NOTE(review): DataFrame.append was removed in pandas 2.0;
            # this requires pandas < 2 — confirm the pinned version.
            csvlogger = csvlogger.append(logs, ignore_index=True)
            csvlogger.to_csv(os.path.join(
                args.result_path,
                '{}/{}/history/epoch.csv'.format(args.dataset, args.stamp)),
                index=False)

        if args.tensorboard:
            with train_writer.as_default():
                tf.summary.scalar('loss', metrics['loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    # NOTE(review): 'acc' is not a key of `metrics` (see above).
                    tf.summary.scalar('acc', metrics['acc'].result(), step=epoch)
            with val_writer.as_default():
                tf.summary.scalar('val_loss', metrics['val_loss'].result(), step=epoch)
                if args.loss == 'crossentropy':
                    # NOTE(review): same key mismatch for 'val_acc'.
                    tf.summary.scalar('val_acc', metrics['val_acc'].result(), step=epoch)

        # reset metric state for the next epoch
        for k, v in metrics.items():
            v.reset_states()
def main(args=None):
    """Training entry point (dataset stats + metric setup).

    NOTE(review): this definition is cut off mid-statement in the visible
    source (`model =` has no right-hand side and the remainder of the
    function is missing), so only comments were added here; every code
    token is unchanged and the truncation is preserved as-is.
    """
    set_seed()
    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # training was already finished!
        return
    elif initial_epoch == 0:
        # first training or training with snapshot
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for k, v in vars(args).items():
        logger.info("{} : {}".format(k, v))

    ##########################
    # Strategy
    ##########################
    strategy = tf.distribute.MirroredStrategy()
    num_workers = strategy.num_replicas_in_sync
    assert args.batch_size % strategy.num_replicas_in_sync == 0
    logger.info('{} : {}'.format(strategy.__class__.__name__,
                                 strategy.num_replicas_in_sync))
    logger.info("GLOBAL BATCH SIZE : {}".format(args.batch_size))

    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.data_path, args.dataset)
    # NOTE(review): when args.steps is given, validation_steps is never
    # assigned, and datasets outside {cifar10, svhn, imagenet} set neither
    # variable — later uses would raise NameError; confirm callers always
    # hit a full branch.
    if args.steps is not None:
        steps_per_epoch = args.steps
    elif args.dataset == 'cifar10':
        steps_per_epoch = 50000 // args.batch_size
        validation_steps = 10000 // args.batch_size
    elif args.dataset == 'svhn':
        steps_per_epoch = 73257 // args.batch_size
        validation_steps = 26032 // args.batch_size
    elif args.dataset == 'imagenet':
        steps_per_epoch = len(trainset) // args.batch_size
        validation_steps = len(valset) // args.batch_size

    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== trainset ==========")
    logger.info("    --> {}".format(len(trainset)))
    logger.info("    --> {}".format(steps_per_epoch))
    logger.info("=========== valset ===========")
    logger.info("    --> {}".format(len(valset)))
    logger.info("    --> {}".format(validation_steps))

    ##########################
    # Model & Metric & Generator
    ##########################
    # 'total_loss'/'unsup_loss' suggest a semi-supervised objective —
    # presumably consumed by the missing remainder of this function.
    metrics = {
        'acc'       : tf.keras.metrics.CategoricalAccuracy('acc', dtype=tf.float32),
        'val_acc'   : tf.keras.metrics.CategoricalAccuracy('val_acc', dtype=tf.float32),
        'loss'      : tf.keras.metrics.Mean('loss', dtype=tf.float32),
        'val_loss'  : tf.keras.metrics.Mean('val_loss', dtype=tf.float32),
        'total_loss': tf.keras.metrics.Mean('total_loss', dtype=tf.float32),
        'unsup_loss': tf.keras.metrics.Mean('unsup_loss', dtype=tf.float32)}

    with strategy.scope():
        model =
def main():
    """Launch supervised training / distillation for the configured dataset."""
    args = get_arguments()
    set_seed(args.seed)
    args.classes = CLASS_DICT[args.dataset]

    args, initial_epoch = search_same(args)
    if initial_epoch == -1:
        # This configuration already ran to completion.
        return
    if initial_epoch == 0:
        # Brand-new run: assign a fresh stamp.
        args.stamp = create_stamp()

    get_session(args)
    logger = get_logger("MyLogger")
    for key, value in vars(args).items():
        logger.info(f"{key} : {value}")

    ##########################
    # Strategy
    ##########################
    single_gpu = len(args.gpus.split(',')) <= 1
    if single_gpu:
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")
    else:
        strategy = tf.distribute.experimental.CentralStorageStrategy()
    replicas = strategy.num_replicas_in_sync
    assert args.batch_size % replicas == 0
    logger.info(f"{strategy.__class__.__name__} : {replicas}")
    logger.info(f"GLOBAL BATCH SIZE : {args.batch_size}")

    ##########################
    # Dataset
    ##########################
    trainset, valset = set_dataset(args.dataset, args.classes, args.data_path)
    steps_per_epoch = args.steps or len(trainset) // args.batch_size
    validation_steps = len(valset) // args.batch_size
    logger.info("TOTAL STEPS OF DATASET FOR TRAINING")
    logger.info("========== TRAINSET ==========")
    logger.info(f" --> {len(trainset)}")
    logger.info(f" --> {steps_per_epoch}")
    logger.info("=========== VALSET ===========")
    logger.info(f" --> {len(valset)}")
    logger.info(f" --> {validation_steps}")

    ##########################
    # Model
    ##########################
    with strategy.scope():
        model = set_model(args.backbone, args.dataset, args.classes)
        if args.snapshot:
            model.load_weights(args.snapshot)
            logger.info(f"Load weights at {args.snapshot}")

        top1 = tf.keras.metrics.TopKCategoricalAccuracy(k=1, name='acc1')
        top5 = tf.keras.metrics.TopKCategoricalAccuracy(k=5, name='acc5')
        model.compile(
            loss=args.loss,
            optimizer=tf.keras.optimizers.SGD(args.lr, momentum=.9),
            metrics=[top1, top5],
            xe_loss=tf.keras.losses.categorical_crossentropy,
            cls_loss=tf.keras.losses.KLD,
            cls_lambda=args.loss_weight,
            temperature=args.temperature,
            num_workers=replicas,
            run_eagerly=True)

    ##########################
    # Generator
    ##########################
    train_generator = DataLoader(
        loss=args.loss,
        mode='train',
        datalist=trainset,
        dataset=args.dataset,
        classes=args.classes,
        batch_size=args.batch_size,
        shuffle=True).dataloader()

    val_generator = DataLoader(
        loss='crossentropy',
        mode='val',
        datalist=valset,
        dataset=args.dataset,
        classes=args.classes,
        batch_size=args.batch_size,
        shuffle=False).dataloader()

    ##########################
    # Train
    ##########################
    callbacks, initial_epoch = create_callbacks(args, logger, initial_epoch)
    if callbacks == -1:
        logger.info('Check your model.')
        return
    if callbacks == -2:
        return

    model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=args.epochs,
        callbacks=callbacks,
        initial_epoch=initial_epoch,
        steps_per_epoch=steps_per_epoch,
        validation_steps=validation_steps)