예제 #1
0
def run_train(args):
    """Fine-tune a pretrained network on the "train"/"val" datasets.

    Validates ``args``, loads the model referenced by ``args.net`` (at the
    epoch parsed from that path) on the selected GPU(s), makes every layer
    trainable, optionally adds L2 weight decay, compiles with SGD and an MSE
    loss, builds the training/validation ``Dataset`` objects and runs
    ``model.fit_generator`` with LR-schedule, checkpoint and TensorBoard
    callbacks.

    Args:
        args: parsed options. Attributes read directly here: ``net``, ``lr``,
            ``n_training_epochs``, ``batch_size``, ``resume``,
            ``selected_gpu``, ``weight_decay``, ``momentum``, ``cutout``,
            ``augmentation``, ``preprocessing`` (helpers receiving ``args``
            may read more). Either an ``argparse.Namespace``-style object or
            an attribute-accessible mapping that also provides ``items()``.
    """
    # check parameters
    check_train_paramenters(args)
    pretrained_net = args.net[:-1] if args.net.endswith("/") else args.net
    pretrained_net = os.path.split(pretrained_net)[1]
    check_net(pretrained_net)

    # Learning rate: lr_string_parse splits args.lr into
    # (initial rate, decay factor, decay step in epochs).
    initial_learning_rate, learning_rate_decay_factor, learning_rate_decay_epochs = lr_string_parse(
        args.lr)

    n_training_epochs = args.n_training_epochs
    batch_size = args.batch_size

    # model tracing
    net_path, initial_epoch = net_path_parse(args.net)
    path_parts = net_path.split("/")
    # Guard the [-3] lookup: a path with fewer than three components would
    # otherwise raise IndexError before the model is even loaded.
    resuming_finetune = (args.resume and len(path_parts) >= 3
                         and path_parts[-3].startswith("fine_tuned_chalearnlap"))
    if resuming_finetune:
        print('Resuming finetuning {} from epoch {}...'.format(
            "fine_tuned_chalearnlap".upper(), initial_epoch))
    else:
        print('Starting finetuning from epoch {}...'.format(initial_epoch))

    # use of GPU and model loading
    model, input_shape = load_model_multiple_gpu(
        net_path.format(epoch=initial_epoch), args.selected_gpu)

    # Fine-tune the whole network: every layer is trainable.
    for layer in model.layers:
        layer.trainable = True
    model.summary()

    # model compiling
    if args.weight_decay:
        weight_decay = args.weight_decay  # e.g. 0.0005
        for layer in model.layers:
            # DepthwiseConv2D subclasses Conv2D but stores its weights as
            # `depthwise_kernel`, so it is excluded from the kernel penalty.
            is_plain_conv = (isinstance(layer, keras.layers.Conv2D)
                             and not isinstance(layer, keras.layers.DepthwiseConv2D))
            if is_plain_conv or isinstance(layer, keras.layers.Dense):
                layer.add_loss(
                    keras.regularizers.l2(weight_decay)(layer.kernel))
            if hasattr(layer, 'bias_regularizer') and layer.use_bias:
                layer.add_loss(keras.regularizers.l2(weight_decay)(layer.bias))
    optimizer = keras.optimizers.sgd(momentum=0.9) if args.momentum else 'sgd'
    loss = keras.losses.mean_squared_error
    accuracy_metrics = [keras.metrics.mean_squared_error]
    model.compile(loss=loss, optimizer=optimizer, metrics=accuracy_metrics)

    out_path, logdir = output_path_generation(args)

    # Augmentation loading (cutout takes precedence over args.augmentation).
    if args.cutout:
        from cropout_test import CropoutAugmentation
        custom_augmentation = CropoutAugmentation()
    elif args.augmentation == 'default':
        from dataset_tools import DefaultAugmentation
        custom_augmentation = DefaultAugmentation()
    elif args.augmentation == 'vggface2':
        from dataset_tools import VGGFace2Augmentation
        custom_augmentation = VGGFace2Augmentation()
    else:
        custom_augmentation = None

    # load dataset
    print("Loading datasets...")
    print("Input shape:", input_shape, type(input_shape))
    print("Preprocessing:", args.preprocessing)
    print("Augmentation:", custom_augmentation)
    dataset_training = Dataset("train",
                               target_shape=input_shape,
                               augment=False,
                               preprocessing=args.preprocessing,
                               custom_augmentation=custom_augmentation)
    dataset_validation = Dataset("val",
                                 target_shape=input_shape,
                                 augment=False,
                                 preprocessing=args.preprocessing)

    # Only continue the epoch counter when resuming; otherwise start at 0.
    train_initial_epoch = initial_epoch if args.resume else 0

    # Training
    print("Training out path", out_path)
    print("Training parameters:")
    # argparse.Namespace has no items(); use its attribute dict in that case.
    params = args if hasattr(args, "items") else vars(args)
    for p, v in params.items():
        print(p, ":", v)
    lr_sched = step_decay_schedule(initial_lr=initial_learning_rate,
                                   decay_factor=learning_rate_decay_factor,
                                   step_size=learning_rate_decay_epochs)
    monitor = 'val_mean_squared_error'
    checkpoint = keras.callbacks.ModelCheckpoint(out_path,
                                                 verbose=1,
                                                 save_best_only=True,
                                                 monitor=monitor)
    tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir,
                                                       write_graph=True,
                                                       write_images=True)
    # TODO patience, min_delta
    # early_stopping = keras.callbacks.EarlyStopping(monitor='val_mean_squared_error', patience=15, min_delta=0.5, verbose=1)
    callbacks_list = [lr_sched, checkpoint, tensorboard_callback]

    model.fit_generator(
        generator=dataset_training.get_generator(batch_size),
        validation_data=dataset_validation.get_generator(batch_size),
        verbose=1,
        callbacks=callbacks_list,
        epochs=n_training_epochs,
        workers=8,
        initial_epoch=train_initial_epoch)
예제 #2
0
    # Scan the entries of directory `d` and return (highest epoch, entry name).
    # The epoch is read from the first capture group of `ep_re`. `d` and
    # `ep_re` are defined outside the visible lines — presumably parameters of
    # the enclosing function; verify against its signature.
    max_ep = 0
    max_c = None
    for c in os.listdir(d):
        epoch_num = re.search(ep_re, c)
        if epoch_num is not None:
            # groups(1) substitutes 1 for a group that did not participate;
            # [0] takes the first group. NOTE(review): raises IndexError if
            # ep_re defines no capture group.
            epoch_num = int(epoch_num.groups(1)[0])
            # Strict '>' against an initial 0: an entry at epoch 0 (or below)
            # is never selected, and on ties the first entry in os.listdir
            # order wins. NOTE(review): confirm this is intended.
            if epoch_num > max_ep:
                max_ep = epoch_num
                max_c = c
    # (0, None) when no entry yields an epoch above 0.
    return max_ep, max_c


# AUGMENTATION
# Choose the custom data-augmentation object. args.cutout takes precedence
# over args.augmentation; any unrecognised args.augmentation value leaves
# augmentation disabled (None). Each augmentation module is imported only
# when its branch is taken.
custom_augmentation = None
if args.cutout:
    from cropout_test import CropoutAugmentation
    custom_augmentation = CropoutAugmentation()
elif args.augmentation == 'vggface2':
    from dataset_tools import VGGFace2Augmentation
    custom_augmentation = VGGFace2Augmentation()
elif args.augmentation == 'default':
    from dataset_tools import DefaultAugmentation
    custom_augmentation = DefaultAugmentation()
elif args.augmentation == 'autoaugment-rafdb':
    from autoaug_test import MyAutoAugmentation
    from autoaugment.rafdb_policies import rafdb_policies
    custom_augmentation = MyAutoAugmentation(rafdb_policies)

if args.mode.startswith('train'):
    print("TRAINING %s" % dirnm)