def _args_items(args):
    """Return the (name, value) pairs of the parsed options ``args``.

    Mapping-like objects that expose a callable ``items()`` are used as-is;
    anything else (e.g. an ``argparse.Namespace``, which has no ``items()``)
    is read through ``vars()``.
    """
    items = getattr(args, 'items', None)
    if callable(items):
        return items()
    return vars(args).items()


def _add_l2_weight_decay(model, weight_decay):
    """Attach L2 penalties to the layers of ``model`` via ``layer.add_loss``.

    Kernels are penalised for ``Conv2D`` layers (excluding
    ``DepthwiseConv2D``) and for ``Dense`` layers. Biases are penalised for
    every layer that has a ``bias_regularizer`` attribute and ``use_bias``
    enabled.
    """
    regularizer = keras.regularizers.l2(weight_decay)
    for layer in model.layers:
        is_plain_conv = (isinstance(layer, keras.layers.Conv2D)
                         and not isinstance(layer, keras.layers.DepthwiseConv2D))
        if is_plain_conv or isinstance(layer, keras.layers.Dense):
            layer.add_loss(regularizer(layer.kernel))
        # NOTE(review): bias decay is applied independently of the kernel
        # check above (so e.g. DepthwiseConv2D biases are decayed too);
        # confirm this matches the original intent.
        if hasattr(layer, 'bias_regularizer') and layer.use_bias:
            layer.add_loss(regularizer(layer.bias))


def _select_custom_augmentation(args):
    """Return the augmentation object selected by ``args``, or ``None``.

    ``args.cutout`` takes precedence; otherwise ``args.augmentation`` may be
    ``'default'`` or ``'vggface2'``. Any other value means no custom
    augmentation. Imports are local so each augmentation module is only
    imported when it is selected.
    """
    if args.cutout:
        from cropout_test import CropoutAugmentation
        return CropoutAugmentation()
    if args.augmentation == 'default':
        from dataset_tools import DefaultAugmentation
        return DefaultAugmentation()
    if args.augmentation == 'vggface2':
        from dataset_tools import VGGFace2Augmentation
        return VGGFace2Augmentation()
    return None


def run_train(args):
    """Fine-tune the network referenced by ``args.net`` on train/val data.

    The model is loaded from the checkpoint path parsed from ``args.net``
    (formatted with the parsed epoch). Every layer is made trainable, and the
    model is compiled with SGD and a mean-squared-error loss and metric. It is
    then trained with ``fit_generator`` using a step-decay learning-rate
    schedule, a ``ModelCheckpoint`` that keeps only the best
    ``val_mean_squared_error`` model, and TensorBoard logging.

    Args:
        args: parsed options. Attributes read directly here: ``net``, ``lr``,
            ``n_training_epochs``, ``batch_size``, ``resume``,
            ``selected_gpu``, ``weight_decay``, ``momentum``, ``cutout``,
            ``augmentation`` and ``preprocessing``. It is also passed whole to
            project helpers (``check_train_paramenters``,
            ``output_path_generation``), which may read more. It may be an
            ``argparse.Namespace`` or a mapping-like object with ``items()``.
    """
    # Check parameters
    check_train_paramenters(args)
    pretrained_net = os.path.basename(args.net.rstrip("/"))
    check_net(pretrained_net)

    # Learning-rate schedule: initial value, decay factor, decay step (epochs)
    (initial_learning_rate, learning_rate_decay_factor,
     learning_rate_decay_epochs) = lr_string_parse(args.lr)
    n_training_epochs = args.n_training_epochs
    batch_size = args.batch_size

    # Model tracing: figure out which checkpoint/epoch to start from.
    net_path, initial_epoch = net_path_parse(args.net)
    path_parts = net_path.split("/")
    # Check args.resume first and guard the length, so short paths do not
    # raise IndexError when we are not resuming.
    resuming = (args.resume and len(path_parts) >= 3
                and path_parts[-3].startswith("fine_tuned_chalearnlap"))
    if resuming:
        print('Resuming finetuning {} from epoch {}...'.format(
            "fine_tuned_chalearnlap".upper(), initial_epoch))
    else:
        print('Starting finetuning from epoch {}...'.format(initial_epoch))

    # GPU selection and model loading
    model, input_shape = load_model_multiple_gpu(
        net_path.format(epoch=initial_epoch), args.selected_gpu)

    # Fine-tune the whole network: every layer is trainable. This must happen
    # before compile() for the change to take effect.
    for layer in model.layers:
        layer.trainable = True
    model.summary()

    # Model compiling
    if args.weight_decay:
        _add_l2_weight_decay(model, args.weight_decay)  # e.g. 0.0005
    optimizer = keras.optimizers.SGD(momentum=0.9) if args.momentum else 'sgd'
    model.compile(loss=keras.losses.mean_squared_error,
                  optimizer=optimizer,
                  metrics=[keras.metrics.mean_squared_error])

    out_path, logdir = output_path_generation(args)

    # Augmentation loading
    custom_augmentation = _select_custom_augmentation(args)

    # Load datasets
    print("Loading datasets...")
    print("Input shape:", input_shape, type(input_shape))
    print("Preprocessing:", args.preprocessing)
    print("Augmentation:", custom_augmentation)
    dataset_training = Dataset("train", target_shape=input_shape, augment=False,
                               preprocessing=args.preprocessing,
                               custom_augmentation=custom_augmentation)
    dataset_validation = Dataset("val", target_shape=input_shape, augment=False,
                                 preprocessing=args.preprocessing)

    # Keras epoch counter continues from the checkpoint only when resuming.
    train_initial_epoch = initial_epoch if args.resume else 0

    # Training
    print("Training out path", out_path)
    print("Training parameters:")
    for p, v in _args_items(args):
        print(p, ":", v)

    lr_sched = step_decay_schedule(initial_lr=initial_learning_rate,
                                   decay_factor=learning_rate_decay_factor,
                                   step_size=learning_rate_decay_epochs)
    # Name Keras gives to the validation value of the MSE metric above.
    monitor = 'val_mean_squared_error'
    checkpoint = keras.callbacks.ModelCheckpoint(out_path, verbose=1,
                                                 save_best_only=True,
                                                 monitor=monitor)
    tb_callback = keras.callbacks.TensorBoard(log_dir=logdir, write_graph=True,
                                              write_images=True)
    # TODO: tune patience/min_delta and re-enable EarlyStopping on `monitor`.
    callbacks_list = [lr_sched, checkpoint, tb_callback]

    model.fit_generator(
        generator=dataset_training.get_generator(batch_size),
        validation_data=dataset_validation.get_generator(batch_size),
        verbose=1,
        callbacks=callbacks_list,
        epochs=n_training_epochs,
        workers=8,
        initial_epoch=train_initial_epoch)
max_ep = 0 max_c = None for c in os.listdir(d): epoch_num = re.search(ep_re, c) if epoch_num is not None: epoch_num = int(epoch_num.groups(1)[0]) if epoch_num > max_ep: max_ep = epoch_num max_c = c return max_ep, max_c # AUGMENTATION if args.cutout: from cropout_test import CropoutAugmentation custom_augmentation = CropoutAugmentation() elif args.augmentation == 'autoaugment-rafdb': from autoaug_test import MyAutoAugmentation from autoaugment.rafdb_policies import rafdb_policies custom_augmentation = MyAutoAugmentation(rafdb_policies) elif args.augmentation == 'default': from dataset_tools import DefaultAugmentation custom_augmentation = DefaultAugmentation() elif args.augmentation == 'vggface2': from dataset_tools import VGGFace2Augmentation custom_augmentation = VGGFace2Augmentation() else: custom_augmentation = None if args.mode.startswith('train'): print("TRAINING %s" % dirnm)