Ejemplo n.º 1
0
    loss = lossobj(masks, preds)

    valid_loss(loss)
    valid_accuracy(masks, preds)


# Per-epoch history buffers for training/validation loss and accuracy.
# NOTE(review): nothing is appended to these in the visible fragment; the
# loop below appears truncated after `template1` — confirm how they are used.
tloss = []
vloss = []
taccuracy = []
vaccuracy = []
# Best loss seen so far, seeded with sys.maxsize — presumably so that the
# first epoch always counts as an improvement.
bestloss = sys.maxsize

# training through epochs
for epoch in range(args.epochs):
    # Clear the running training metrics at the start of every epoch.
    train_loss.reset_states()
    train_accuracy.reset_states()

    # Validation metrics are only reset (and, below, only updated) when
    # args.train is falsy.
    if not args.train:
        valid_loss.reset_states()
        valid_accuracy.reset_states()

    for imgs, masks in train_data:
        trainStep(imgs, masks)

    # When args.train is set, the validation batches are fed to trainStep
    # as extra training data instead of being evaluated with validStep.
    for imgs, masks in valid_data:
        if not args.train:
            validStep(imgs, masks)
        else:
            trainStep(imgs, masks)

    # Format string for a per-epoch training summary (not used within the
    # visible fragment).
    template1 = "epoch[{0}/{1}] training: mean loss: {2}, accuracy: {3}"
Ejemplo n.º 2
0
def train(args):
    """Train a model's weights and, after a warm-up, its architecture parameters.

    Every epoch first updates the regular trainable weights on the
    ``train_weights`` split, then evaluates on the ``validation`` split. From
    epoch ``arch_warmup_epochs`` (10) onwards it also updates the architecture
    parameters on the ``train_arch`` split and evaluates again. Losses and
    accuracies are printed and written to TensorBoard. At the end of each
    epoch, the value of ``temperature_decay_fn(epoch)`` is passed to
    ``define_temperature``. After the last epoch the architecture parameters
    are printed and ``post_training_analysis`` is run on the model.

    Args:
        args: configuration mapping indexed by string keys, e.g.
            ``'dataloader'``, ``'num_classes'``, ``'model_name'``,
            ``'framework'``, ``'input_size'``, ``'optimizer'``,
            ``'arch_optimizer_param'``, ``'temperature'``, ``'num_epochs'``,
            ``'log_dir'``, ``'server'`` and ``'exported_architecture'``.
    """
    # Epoch index from which the architecture parameters are also updated.
    arch_warmup_epochs = 10

    # Create log, checkpoint and export directories.
    checkpoint_dir, log_dir, export_dir = create_env_directories(
        args, get_experiment_name(args))

    # Separate splits: one for weight updates, one for architecture-parameter
    # updates, one for validation.
    train_weight_dataset = dataloader.get_dataset(
        args['dataloader'],
        transformation_list=args['dataloader']['train_list'],
        num_classes=args["num_classes"],
        split='train_weights')
    train_arch_dataset = dataloader.get_dataset(
        args['dataloader'],
        transformation_list=args['dataloader']['train_list'],
        num_classes=args["num_classes"],
        split='train_arch')
    val_dataset = dataloader.get_dataset(
        args['dataloader'],
        transformation_list=args['dataloader']['val_list'],
        num_classes=args["num_classes"],
        split='validation')

    setup_mp(args)

    # Define model, optimizers and checkpoint callback.
    model = model_name_to_class[args['model_name']](
        args['framework'],
        input_shape=args['input_size'],
        label_dim=args['num_classes']).model
    model.summary()
    alchemy_api.send_model_info(model, args['server'])
    weight_opt = get_optimizer(args['optimizer'])
    arch_opt = get_optimizer(args['arch_optimizer_param'])
    # NOTE(review): model_checkpoint_cb is never invoked in this function, so
    # this loop presumably writes no checkpoints even though it resumes from
    # latest_epoch — confirm whether it should be triggered every epoch.
    model_checkpoint_cb, latest_epoch = init_custom_checkpoint_callbacks(
        {'model': model}, checkpoint_dir)

    # Regular weights vs. architecture parameters; each gets its own optimizer.
    weights, arch_params = split_trainable_weights(model)
    temperature_decay_fn = exponential_decay(
        args['temperature']['init_value'], args['temperature']['decay_steps'],
        args['temperature']['decay_rate'])

    lr_decay_fn = CosineDecay(
        args['optimizer']['lr'],
        alpha=args["optimizer"]["lr_decay_strategy"]["lr_params"]["alpha"],
        total_epochs=args['num_epochs'])

    loss_fn = CategoricalCrossentropy()
    accuracy_metric = CategoricalAccuracy()
    loss_metric = Mean()
    val_accuracy_metric = CategoricalAccuracy()
    val_loss_metric = Mean()

    # NOTE(review): these writers use args['log_dir'], while the `log_dir`
    # returned by create_env_directories above is otherwise unused — confirm
    # which directory TensorBoard logs are meant to go to.
    train_log_dir = os.path.join(args['log_dir'], 'train')
    val_log_dir = os.path.join(args['log_dir'], 'validation')
    train_summary_writer = tf.summary.create_file_writer(train_log_dir)
    val_summary_writer = tf.summary.create_file_writer(val_log_dir)

    @tf.function
    def train_step(x_batch, y_batch):
        """One optimisation step on the regular weights only."""
        with tf.GradientTape() as tape:
            y_hat = model(x_batch, training=True)
            loss = loss_fn(y_batch, y_hat)

        accuracy_metric.update_state(y_batch, y_hat)
        loss_metric.update_state(loss)
        grads = tape.gradient(loss, weights)
        weight_opt.apply_gradients(zip(grads, weights))

    @tf.function
    def train_step_arch(x_batch, y_batch):
        """One optimisation step on the architecture parameters only."""
        with tf.GradientTape() as tape:
            y_hat = model(x_batch, training=False)
            loss = loss_fn(y_batch, y_hat)

        accuracy_metric.update_state(y_batch, y_hat)
        loss_metric.update_state(loss)
        grads = tape.gradient(loss, arch_params)
        arch_opt.apply_gradients(zip(grads, arch_params))

    @tf.function
    def evaluation_step(x_batch, y_batch):
        """Forward pass on a validation batch; updates validation metrics."""
        y_hat = model(x_batch, training=False)
        loss = loss_fn(y_batch, y_hat)

        val_accuracy_metric.update_state(y_batch, y_hat)
        val_loss_metric.update_state(loss)

    def _validate():
        """Run evaluation_step over the whole validation split."""
        for x_batch, y_batch in val_dataset:
            evaluation_step(x_batch, y_batch)

    def _read_metrics():
        """Return (train_accuracy, train_loss, val_accuracy, val_loss)."""
        return (accuracy_metric.result(), loss_metric.result(),
                val_accuracy_metric.result(), val_loss_metric.result())

    def _reset_metrics():
        """Reset all four metrics so the next phase starts from scratch."""
        accuracy_metric.reset_states()
        loss_metric.reset_states()
        val_accuracy_metric.reset_states()
        val_loss_metric.reset_states()

    def _write_scalars(writer, scalars, step):
        """Write each (name, value) pair in `scalars` to `writer` at `step`."""
        with writer.as_default():
            for name, value in scalars:
                tf.summary.scalar(name, value, step=step)

    for epoch in range(latest_epoch, args['num_epochs']):
        print(f'Epoch: {epoch}/{args["num_epochs"]}')

        weight_opt.learning_rate = lr_decay_fn(epoch)

        # Phase 1: update the weight parameters on their training subset.
        for x_batch, y_batch in tqdm.tqdm(train_weight_dataset):
            train_step(x_batch, y_batch)

        _validate()
        train_accuracy, train_loss, val_accuracy, val_loss = _read_metrics()

        print(f'Weights updated, Epoch {epoch}, Train Loss: {float(train_loss)}, Train Accuracy: '
              f'{float(train_accuracy)}, Val Loss: {float(val_loss)}, Val Accuracy: {float(val_accuracy)}, '
              f'lr: {float(weight_opt.learning_rate)}')

        # Computed (and logged) now, but only applied through
        # define_temperature at the very end of this epoch.
        new_temperature = temperature_decay_fn(epoch)

        _write_scalars(train_summary_writer,
                       [('loss', train_loss),
                        ('accuracy', train_accuracy),
                        ('temperature', new_temperature)],
                       epoch)
        _write_scalars(val_summary_writer,
                       [('loss', val_loss),
                        ('accuracy', val_accuracy)],
                       epoch)

        _reset_metrics()

        if epoch >= arch_warmup_epochs:
            # Phase 2: update the architectural parameters on another subset.
            for x_batch, y_batch in tqdm.tqdm(train_arch_dataset):
                train_step_arch(x_batch, y_batch)

            _validate()
            train_accuracy, train_loss, val_accuracy, val_loss = _read_metrics()

            print(f'Arch params updated, Epoch {epoch}, Train Loss: {float(train_loss)}, Train Accuracy: '
                  f'{float(train_accuracy)}, Val Loss: {float(val_loss)}, Val Accuracy: {float(val_accuracy)}')

            _write_scalars(train_summary_writer,
                           [('loss_after_arch_params_update', train_loss),
                            ('accuracy_after_arch_params_update', train_accuracy)],
                           epoch)
            _write_scalars(val_summary_writer,
                           [('loss_after_arch_params_update', val_loss),
                            ('accuracy_after_arch_params_update', val_accuracy)],
                           epoch)

            _reset_metrics()

        define_temperature(new_temperature)

    print("Training Completed!!")

    print("Architecture params: ")
    print(arch_params)
    post_training_analysis(model, args['exported_architecture'])