def train_simclr(np_train, testing_flag, window_size, batch_size,
                 non_testing_simclr_epochs, transformation_indices,
                 initial_learning_rate, non_testing_linear_eval_epochs,
                 aggregate, temperature):
    decay_steps = 1000

    if testing_flag:
        epochs = 1
    else:
        epochs = non_testing_simclr_epochs

    input_shape = (window_size, 6)

    transform_funcs_vectorised = [
        hchs_transformations.noise_transform_vectorized,
        hchs_transformations.scaling_transform_vectorized,
        hchs_transformations.negate_transform_vectorized,
        hchs_transformations.time_flip_transform_vectorized,
        hchs_transformations.time_segment_permutation_transform_improved,
        hchs_transformations.time_warp_transform_low_cost,
        hchs_transformations.channel_shuffle_transform_vectorized
    ]
    transform_funcs_names = [
        'noised', 'scaled', 'negated', 'time_flipped', 'permuted',
        'time_warped', 'channel_shuffled'
    ]

    tf.keras.backend.set_floatx('float32')

    # lr_decayed_fn = tf.keras.experimental.CosineDecay(initial_learning_rate=initial_learning_rate, decay_steps=decay_steps)
    # optimizer = tf.keras.optimizers.SGD(lr_decayed_fn)

    # SimCLR linear LR scaling rule (0.3 * batch_size / 256); overrides initial_learning_rate
    learning_rate = 0.3 * (batch_size / 256)
    optimizer = LARSOptimizer(learning_rate, weight_decay=0.000001)
    transformation_function = simclr_utitlities.generate_combined_transform_function(
        transform_funcs_vectorised, indices=transformation_indices)

    base_model = simclr_models.create_base_model(input_shape,
                                                 model_name="base_model")
    simclr_model = simclr_models.attach_simclr_head(base_model)
    # simclr_model.summary()
    trained_simclr_model, epoch_losses = simclr_utitlities.simclr_train_model(
        simclr_model,
        np_train,
        optimizer,
        batch_size,
        transformation_function,
        temperature=temperature,
        epochs=epochs,
        is_trasnform_function_vectorized=True,
        verbose=1)

    return trained_simclr_model
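# Hypothetical usage sketch (not part of the original listing): assumes the
# project's hchs_transformations / simclr_models / simclr_utitlities modules and
# LARSOptimizer are importable, and that the training windows are float32 arrays
# shaped (num_windows, window_size, 6). All values below are illustrative only.
import numpy as np

dummy_train = np.random.rand(1024, 500, 6).astype(np.float32)
trained_model = train_simclr(dummy_train,
                             testing_flag=True,              # single-epoch smoke test
                             window_size=500,
                             batch_size=256,
                             non_testing_simclr_epochs=100,
                             transformation_indices=[0, 1],  # noise + scaling
                             initial_learning_rate=0.1,
                             non_testing_linear_eval_epochs=100,
                             aggregate='mean',               # placeholder; unused here
                             temperature=0.1)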
Example 2
def train_simclr(testing_flag, np_train, transformation_indices=[0,1], lr=0.01, batch_size=128, epoch_number=100):
    decay_steps = 1000
    if testing_flag:
        epochs = 1
    else:
        epochs = epoch_number

    temperature = 0.1

    input_shape = (np_train.shape[1], np_train.shape[2])

    transform_funcs_vectorised = [
        chapman_transformations.noise_transform_vectorized, 
        chapman_transformations.scaling_transform_vectorized, 
        chapman_transformations.negate_transform_vectorized, 
        chapman_transformations.time_flip_transform_vectorized, 
        chapman_transformations.time_segment_permutation_transform_improved, 
        chapman_transformations.time_warp_transform_low_cost, 
        chapman_transformations.channel_shuffle_transform_vectorized
    ]

    transform_funcs_names = ['noised', 'scaled', 'negated', 'time_flipped', 'permuted', 'time_warped', 'channel_shuffled']
    tf.keras.backend.set_floatx('float32')

    # lr_decayed_fn = tf.keras.experimental.CosineDecay(initial_learning_rate=lr, decay_steps=decay_steps)
    # optimizer = tf.keras.optimizers.SGD(lr_decayed_fn)
    learning_rate = 0.3 * (batch_size / 256)
    # note: the cosine schedule below is created but never passed to the optimizer,
    # which uses the constant scaled learning rate instead
    lr_decayed_fn = tf.keras.experimental.CosineDecay(initial_learning_rate=learning_rate, decay_steps=decay_steps)
    optimizer = LARSOptimizer(learning_rate, weight_decay=0.000001)

    transformation_function = simclr_utitlities.generate_combined_transform_function(transform_funcs_vectorised, indices=transformation_indices)

    base_model = simclr_models.create_base_model(input_shape, model_name="base_model")
    simclr_model = simclr_models.attach_simclr_head(base_model)
    simclr_model.summary()

    trained_simclr_model, epoch_losses = simclr_utitlities.simclr_train_model(simclr_model, np_train, optimizer, batch_size, transformation_function, temperature=temperature, epochs=epochs, is_trasnform_function_vectorized=True, verbose=1)
    save_model_directory = os.path.join("save_models", "simclr")
    try:
        if not os.path.exists(save_model_directory):
            os.makedirs(save_model_directory)
    except OSError as err:
        print(err)

    start_time = datetime.datetime.now()
    start_time_str = start_time.strftime("%Y%m%d-%H%M%S")
    save_path = os.path.join(save_model_directory, f'{start_time_str}-{testing_flag}-{transformation_indices}-{batch_size}-{epoch_number}.h5')
    trained_simclr_model.save(save_path)
    return trained_simclr_model, epoch_losses
Example 3
def get_optimizer(learning_rate):
    """Returns an optimizer."""
    if FLAGS.optimizer == 'momentum':
        optimizer = tf.train.MomentumOptimizer(learning_rate,
                                               FLAGS.momentum,
                                               use_nesterov=True)
    elif FLAGS.optimizer == 'adam':
        optimizer = tf.train.AdamOptimizer(learning_rate)
    elif FLAGS.optimizer == 'lars':
        optimizer = LARSOptimizer(learning_rate,
                                  momentum=FLAGS.momentum,
                                  weight_decay=FLAGS.weight_decay,
                                  exclude_from_weight_decay=[
                                      'batch_normalization', 'bias',
                                      'head_supervised'
                                  ])
    else:
        raise ValueError('Unknown optimizer {}'.format(FLAGS.optimizer))

    if FLAGS.use_tpu:
        optimizer = tf.tpu.CrossShardOptimizer(optimizer)
    return optimizer
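# Hypothetical usage sketch (not part of the original listing): wiring
# get_optimizer() into a TF1-style training graph. Assumes absl FLAGS have been
# parsed (e.g. --optimizer=lars --momentum=0.9 --weight_decay=1e-6) and that
# `model_util` and `num_train_examples` exist as in the examples below.
def build_train_op(loss):
    learning_rate = model_util.learning_rate_schedule(FLAGS.learning_rate,
                                                      num_train_examples)
    optimizer = get_optimizer(learning_rate)
    return optimizer.minimize(
        loss, global_step=tf.train.get_or_create_global_step())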
Example 4
    def model_fn(features, labels, mode, params=None):
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode
        if FLAGS.train_mode == 'pretrain':
            raise ValueError('Pretraining not possible in multihead config, '
                             'set train_mode to finetune')
        elif FLAGS.train_mode == 'finetune':
            # Base network forward pass.
            with tf.variable_scope('base_model'):
                if (FLAGS.train_mode == 'finetune'
                        and FLAGS.fine_tune_after_block >= 4):
                    # Finetune just supervised (linear) head will not
                    # update BN stats.
                    model_train_mode = False
                else:
                    model_train_mode = is_training
                hiddens = model(features, is_training=model_train_mode)

            logits_animal = model_util.supervised_head(
                hiddens,
                animal_num_classes,
                is_training,
                name='head_supervised_animals')
            obj_lib.add_supervised_loss(labels=labels['animal_label'],
                                        logits=logits_animal,
                                        weights=labels['animal_mask'])

            logits_plant = model_util.supervised_head(
                hiddens,
                plant_num_classes,
                is_training,
                name='head_supervised_plants')
            obj_lib.add_supervised_loss(labels=labels['plant_label'],
                                        logits=logits_plant,
                                        weights=labels['plant_mask'])
            model_util.add_weight_decay(adjust_per_optimizer=True)
            loss = tf.losses.get_total_loss()

            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for j in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train += tf.get_collection(collection_prefix +
                                                        str(j))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty'
            tf.logging.info(
                '===============Variables to train (begin)===============')
            tf.logging.info(variables_to_train)
            tf.logging.info(
                '================Variables to train (end)================')

            learning_rate = model_util.learning_rate_schedule(
                FLAGS.learning_rate, num_train_examples)

            if is_training:
                if FLAGS.train_summary_steps > 0:
                    summary_writer = tf2.summary.create_file_writer(
                        FLAGS.model_dir)
                    with tf.control_dependencies([summary_writer.init()]):
                        with summary_writer.as_default():
                            should_record = tf.math.equal(
                                tf.math.floormod(tf.train.get_global_step(),
                                                 FLAGS.train_summary_steps), 0)
                            with tf2.summary.record_if(should_record):
                                label_acc_animal = tf.equal(
                                    tf.argmax(labels['animal_label'], 1),
                                    tf.argmax(logits_animal, axis=1))
                                label_acc_plant = tf.equal(
                                    tf.argmax(labels['plant_label'], 1),
                                    tf.argmax(logits_plant, axis=1))

                                label_acc = tf.math.logical_or(
                                    label_acc_animal, label_acc_plant)
                                label_acc = tf.reduce_mean(
                                    tf.cast(label_acc, tf.float32))
                                tf2.summary.scalar(
                                    'train_label_accuracy',
                                    label_acc,
                                    step=tf.train.get_global_step())
                                tf2.summary.scalar(
                                    'learning_rate',
                                    learning_rate,
                                    step=tf.train.get_global_step())
                if FLAGS.optimizer == 'momentum':
                    optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                           FLAGS.momentum,
                                                           use_nesterov=True)
                elif FLAGS.optimizer == 'adam':
                    optimizer = tf.train.AdamOptimizer(learning_rate)
                elif FLAGS.optimizer == 'lars':
                    optimizer = LARSOptimizer(learning_rate,
                                              momentum=FLAGS.momentum,
                                              weight_decay=FLAGS.weight_decay,
                                              exclude_from_weight_decay=[
                                                  'batch_normalization', 'bias'
                                              ])
                else:
                    raise ValueError('Unknown optimizer {}'.format(
                        FLAGS.optimizer))

                if FLAGS.use_tpu:
                    optimizer = tf.tpu.CrossShardOptimizer(optimizer)
                control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

                if FLAGS.train_summary_steps > 0:
                    control_deps.extend(tf.summary.all_v2_summary_ops())
                with tf.control_dependencies(control_deps):
                    train_op = optimizer.minimize(
                        loss,
                        global_step=tf.train.get_or_create_global_step(),
                        var_list=variables_to_train)

                if FLAGS.checkpoint:

                    def scaffold_fn():
                        """
                        Scaffold function to restore non-logits vars from
                        checkpoint.
                        """
                        tf.train.init_from_checkpoint(
                            FLAGS.checkpoint, {
                                v.op.name: v.op.name
                                for v in tf.global_variables(
                                    FLAGS.variable_schema)
                            })

                        if FLAGS.zero_init_logits_layer:
                            # Init op that initializes output layer parameters
                            # to zeros.
                            output_layer_parameters = [
                                var for var in tf.trainable_variables()
                                if var.name.startswith('head_supervised')
                            ]
                            tf.logging.info(
                                'Initializing output layer parameters %s to 0',
                                [x.op.name for x in output_layer_parameters])
                            with tf.control_dependencies(
                                [tf.global_variables_initializer()]):
                                init_op = tf.group([
                                    tf.assign(x, tf.zeros_like(x))
                                    for x in output_layer_parameters
                                ])
                            return tf.train.Scaffold(init_op=init_op)
                        else:
                            return tf.train.Scaffold()
                else:
                    scaffold_fn = None

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    train_op=train_op,
                    loss=loss,
                    scaffold_fn=scaffold_fn)

            elif mode == tf.estimator.ModeKeys.PREDICT:
                _, animal_top_5 = tf.nn.top_k(logits_animal, k=5)
                _, plant_top_5 = tf.nn.top_k(logits_plant, k=5)

                predictions = {
                    'label': tf.argmax(labels['label'], 1),
                    'animal_label': tf.argmax(labels['animal_label'], 1),
                    'plant_label': tf.argmax(labels['plant_label'], 1),
                    'animal_top_5': animal_top_5,
                    'plant_top_5': plant_top_5,
                }

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode, predictions=predictions)
            elif mode == tf.estimator.ModeKeys.EVAL:

                def metric_fn(logits_animal, logits_plant, labels_animal,
                              labels_plant, mask_animal, mask_plant, **kws):
                    metrics = {}
                    metrics[
                        'label_animal_top_1_accuracy'] = tf.metrics.accuracy(
                            tf.argmax(labels_animal, 1),
                            tf.argmax(logits_animal, axis=1),
                            weights=mask_animal)
                    metrics[
                        'label_animal_top_5_accuracy'] = tf.metrics.recall_at_k(
                            tf.argmax(labels_animal, 1),
                            logits_animal,
                            k=5,
                            weights=mask_animal)
                    metrics[
                        'label_plant_top_1_accuracy'] = tf.metrics.accuracy(
                            tf.argmax(labels_plant, 1),
                            tf.argmax(logits_plant, axis=1),
                            weights=mask_plant)
                    metrics[
                        'label_plant_top_5_accuracy'] = tf.metrics.recall_at_k(
                            tf.argmax(labels_plant, 1),
                            logits_plant,
                            k=5,
                            weights=mask_plant)
                    return metrics

                metrics = {
                    'logits_animal': logits_animal,
                    'logits_plant': logits_plant,
                    'labels_animal': labels['animal_label'],
                    'labels_plant': labels['plant_label'],
                    'mask_animal': labels['animal_mask'],
                    'mask_plant': labels['plant_mask'],
                }

                return tf.estimator.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=(metric_fn, metrics),
                    scaffold_fn=None)
            else:
                print('Invalid mode.')
Example 5
    def model_fn(features, labels, mode, params=None):
        """Build model and optimizer."""
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        # Check training mode.
        if FLAGS.train_mode == 'pretrain':
            num_transforms = 2
            if FLAGS.fine_tune_after_block > -1:
                raise ValueError(
                    'Does not support layer freezing during pretraining, '
                    'should set fine_tune_after_block<=-1 for safety.')
        elif FLAGS.train_mode == 'finetune':
            num_transforms = 1
        else:
            raise ValueError('Unknown train_mode {}'.format(FLAGS.train_mode))

        # Split channels, and optionally apply extra batched augmentation.
        features_list = tf.split(features,
                                 num_or_size_splits=num_transforms,
                                 axis=-1)
        if FLAGS.use_blur and is_training and FLAGS.train_mode == 'pretrain':
            features_list = data_util.batch_random_blur(
                features_list, FLAGS.image_size, FLAGS.image_size)
        features = tf.concat(features_list,
                             0)  # (num_transforms * bsz, h, w, c)

        # Base network forward pass.
        with tf.variable_scope('base_model'):
            if FLAGS.train_mode == 'finetune' and FLAGS.fine_tune_after_block >= 4:
                # Finetune just supervised (linear) head will not update BN stats.
                model_train_mode = False
            else:
                # Pretraining or finetuning anything else will update BN stats.
                model_train_mode = is_training
            hiddens = model(features, is_training=model_train_mode)

        # Add head and loss.
        if FLAGS.train_mode == 'pretrain':
            tpu_context = params['context'] if 'context' in params else None
            hiddens_proj = model_util.projection_head(hiddens, is_training)
            contrast_loss, logits_con, labels_con = obj_lib.add_contrastive_loss(
                hiddens_proj,
                hidden_norm=FLAGS.hidden_norm,
                temperature=FLAGS.temperature,
                tpu_context=tpu_context if is_training else None)
            logits_sup = tf.zeros([params['batch_size'], num_classes])
        else:
            contrast_loss = tf.zeros([])
            logits_con = tf.zeros([params['batch_size'], 10])
            labels_con = tf.zeros([params['batch_size'], 10])
            logits_sup = model_util.supervised_head(hiddens, num_classes,
                                                    is_training)
            obj_lib.add_supervised_loss(labels=labels['labels'],
                                        logits=logits_sup,
                                        weights=labels['mask'])

        # Add weight decay to loss, for non-LARS optimizers.
        model_util.add_weight_decay(adjust_per_optimizer=True)
        loss = tf.losses.get_total_loss()

        if FLAGS.train_mode == 'pretrain':
            variables_to_train = tf.trainable_variables()
        else:
            collection_prefix = 'trainable_variables_inblock_'
            variables_to_train = []
            for j in range(FLAGS.fine_tune_after_block + 1, 6):
                variables_to_train += tf.get_collection(collection_prefix +
                                                        str(j))
            assert variables_to_train, 'variables_to_train shouldn\'t be empty!'

        tf.logging.info(
            '===============Variables to train (begin)===============')
        tf.logging.info(variables_to_train)
        tf.logging.info(
            '================Variables to train (end)================')

        learning_rate = model_util.learning_rate_schedule(
            FLAGS.learning_rate, num_train_examples)

        if is_training:
            if FLAGS.train_summary_steps > 0:
                # Compute stats for the summary.
                prob_con = tf.nn.softmax(logits_con)
                entropy_con = -tf.reduce_mean(
                    tf.reduce_sum(prob_con * tf.math.log(prob_con + 1e-8), -1))

                summary_writer = tf2.summary.create_file_writer(
                    FLAGS.model_dir)
                # TODO(iamtingchen): remove this control_dependencies in the future.
                with tf.control_dependencies([summary_writer.init()]):
                    with summary_writer.as_default():
                        should_record = tf.math.equal(
                            tf.math.floormod(tf.train.get_global_step(),
                                             FLAGS.train_summary_steps), 0)
                        with tf2.summary.record_if(should_record):
                            contrast_acc = tf.equal(
                                tf.argmax(labels_con, 1),
                                tf.argmax(logits_con, axis=1))
                            contrast_acc = tf.reduce_mean(
                                tf.cast(contrast_acc, tf.float32))
                            label_acc = tf.equal(
                                tf.argmax(labels['labels'], 1),
                                tf.argmax(logits_sup, axis=1))
                            label_acc = tf.reduce_mean(
                                tf.cast(label_acc, tf.float32))
                            tf2.summary.scalar('train_contrast_loss',
                                               contrast_loss,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_contrast_acc',
                                               contrast_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('train_label_accuracy',
                                               label_acc,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('contrast_entropy',
                                               entropy_con,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('learning_rate',
                                               learning_rate,
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_mean',
                                               tf.reduce_mean(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_max',
                                               tf.reduce_max(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('input_min',
                                               tf.reduce_min(features),
                                               step=tf.train.get_global_step())
                            tf2.summary.scalar('num_labels',
                                               tf.reduce_mean(
                                                   tf.reduce_sum(
                                                       labels['labels'], -1)),
                                               step=tf.train.get_global_step())

            if FLAGS.optimizer == 'momentum':
                optimizer = tf.train.MomentumOptimizer(learning_rate,
                                                       FLAGS.momentum,
                                                       use_nesterov=True)
            elif FLAGS.optimizer == 'adam':
                optimizer = tf.train.AdamOptimizer(learning_rate)
            elif FLAGS.optimizer == 'lars':
                optimizer = LARSOptimizer(
                    learning_rate,
                    momentum=FLAGS.momentum,
                    weight_decay=FLAGS.weight_decay,
                    exclude_from_weight_decay=['batch_normalization', 'bias'])
            else:
                raise ValueError('Unknown optimizer {}'.format(
                    FLAGS.optimizer))

            if FLAGS.use_tpu:
                optimizer = tf.tpu.CrossShardOptimizer(optimizer)

            control_deps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            if FLAGS.train_summary_steps > 0:
                control_deps.extend(tf.summary.all_v2_summary_ops())
            with tf.control_dependencies(control_deps):
                train_op = optimizer.minimize(
                    loss,
                    global_step=tf.train.get_or_create_global_step(),
                    var_list=variables_to_train)

            if FLAGS.checkpoint:

                def scaffold_fn():
                    """Scaffold function to restore non-logits vars from checkpoint."""
                    for v in tf.global_variables(FLAGS.variable_schema):
                        print(v.op.name)
                    tf.train.init_from_checkpoint(
                        FLAGS.checkpoint, {
                            v.op.name: v.op.name
                            for v in tf.global_variables(FLAGS.variable_schema)
                        })

                    if FLAGS.zero_init_logits_layer:
                        # Init op that initializes output layer parameters to zeros.
                        output_layer_parameters = [
                            var for var in tf.trainable_variables()
                            if var.name.startswith('head_supervised')
                        ]
                        tf.logging.info(
                            'Initializing output layer parameters %s to zero',
                            [x.op.name for x in output_layer_parameters])
                        with tf.control_dependencies(
                            [tf.global_variables_initializer()]):
                            init_op = tf.group([
                                tf.assign(x, tf.zeros_like(x))
                                for x in output_layer_parameters
                            ])
                        return tf.train.Scaffold(init_op=init_op)
                    else:
                        return tf.train.Scaffold()
            else:
                scaffold_fn = None

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     train_op=train_op,
                                                     loss=loss,
                                                     scaffold_fn=scaffold_fn)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            _, top_5 = tf.nn.top_k(logits_sup, k=5)
            predictions = {
                'label': tf.argmax(labels['labels'], 1),
                'top_5': top_5,
            }
            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     predictions=predictions)
        else:

            def metric_fn(logits_sup, labels_sup, logits_con, labels_con, mask,
                          **kws):
                """Inner metric function."""
                metrics = {
                    k: tf.metrics.mean(v, weights=mask)
                    for k, v in kws.items()
                }
                metrics['label_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_sup, 1),
                    tf.argmax(logits_sup, axis=1),
                    weights=mask)
                metrics['label_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_sup, 1), logits_sup, k=5, weights=mask)
                metrics['contrastive_top_1_accuracy'] = tf.metrics.accuracy(
                    tf.argmax(labels_con, 1),
                    tf.argmax(logits_con, axis=1),
                    weights=mask)
                metrics['contrastive_top_5_accuracy'] = tf.metrics.recall_at_k(
                    tf.argmax(labels_con, 1), logits_con, k=5, weights=mask)

                metrics[
                    'mean_class_accuracy'] = tf.metrics.mean_per_class_accuracy(
                        tf.argmax(labels_sup, 1),
                        tf.argmax(logits_sup, axis=1),
                        num_classes,
                        weights=mask,
                        name='mca')

                running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES,
                                                 scope="mca")
                metrics['mean_class_accuracy_total'] = running_vars[0]
                metrics['mean_class_accuracy_count'] = running_vars[1]

                return metrics

            metrics = {
                'logits_sup': logits_sup,
                'labels_sup': labels['labels'],
                'logits_con': logits_con,
                'labels_con': labels_con,
                'mask': labels['mask'],
                'contrast_loss': tf.fill((params['batch_size'],), contrast_loss),
                'regularization_loss': tf.fill(
                    (params['batch_size'],),
                    tf.losses.get_regularization_loss()),
            }

            return tf.estimator.tpu.TPUEstimatorSpec(mode=mode,
                                                     loss=loss,
                                                     eval_metrics=(metric_fn,
                                                                   metrics),
                                                     scaffold_fn=None)
Example 6
def main():
    args = parse_option()
    print(args)

    # check args
    if args.loss not in LOSS_NAMES:
        raise ValueError('Unsupported loss function type {}'.format(args.loss))

    if args.optimizer == 'adam':
        optimizer1 = tf.keras.optimizers.Adam(lr=args.lr_1)
    elif args.optimizer == 'lars':
        from lars_optimizer import LARSOptimizer
        # not compatible with tf2
        optimizer1 = LARSOptimizer(
            args.lr_1,
            exclude_from_weight_decay=['batch_normalization', 'bias'])
    elif args.optimizer == 'sgd':
        optimizer1 = tfa.optimizers.SGDW(learning_rate=args.lr_1,
                                         momentum=0.9,
                                         weight_decay=1e-4)
    optimizer2 = tf.keras.optimizers.Adam(lr=args.lr_2)

    model_name = '{}_model-bs_{}-lr_{}'.format(args.loss, args.batch_size_1,
                                               args.lr_1)

    # 0. Load data
    if args.data == 'mnist':
        mnist = tf.keras.datasets.mnist
    elif args.data == 'fashion_mnist':
        mnist = tf.keras.datasets.fashion_mnist
    print('Loading {} data...'.format(args.data))
    (_, y_train), (_, y_test) = mnist.load_data()
    # x_train, x_test = x_train / 255.0, x_test / 255.0
    # x_train = x_train.reshape(-1, 28*28).astype(np.float32)
    # x_test = x_test.reshape(-1, 28*28).astype(np.float32)
    (x_train, _), (x_test, _), _, _ = load_mnist()
    # print(x_train[0][0])
    print(x_train.shape, x_test.shape)

    # simulate low data regime for training
    # n_train = x_train.shape[0]
    # shuffle_idx = np.arange(n_train)
    # np.random.shuffle(shuffle_idx)

    # x_train = x_train[shuffle_idx][:args.n_data_train]
    # y_train = y_train[shuffle_idx][:args.n_data_train]
    # print('Training dataset shapes after slicing:')
    print(x_train.shape, y_train.shape)

    train_ds = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(5000).batch(args.batch_size_1)

    train_ds2 = tf.data.Dataset.from_tensor_slices(
        (x_train, y_train)).shuffle(5000).batch(args.batch_size_2)

    test_ds = tf.data.Dataset.from_tensor_slices(
        (x_test, y_test)).batch(args.batch_size_1)

    # 1. Stage 1: train encoder with multiclass N-pair loss
    encoder = Encoder(normalize=True, activation=args.activation)
    projector = Projector(args.projection_dim,
                          normalize=True,
                          activation=args.activation)

    if args.loss == 'max_margin':

        def loss_func(z, y):
            return losses.max_margin_contrastive_loss(z,
                                                      y,
                                                      margin=args.margin,
                                                      metric=args.metric)
    elif args.loss == 'npairs':
        loss_func = losses.multiclass_npairs_loss
    elif args.loss == 'sup_nt_xent':

        def loss_func(z, y):
            return losses.supervised_nt_xent_loss(
                z,
                y,
                temperature=args.temperature,
                base_temperature=args.base_temperature)
    elif args.loss.startswith('triplet'):
        triplet_kind = args.loss.split('-')[1]

        def loss_func(z, y):
            return losses.triplet_loss(z,
                                       y,
                                       kind=triplet_kind,
                                       margin=args.margin)

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    test_loss = tf.keras.metrics.Mean(name='test_loss')

    # tf.config.experimental_run_functions_eagerly(True)
    # train step for the contrastive loss
    @tf.function
    def train_step_stage1(x, y):
        '''
        x: data tensor, shape: (batch_size, data_dim)
        y: data labels, shape: (batch_size, )
        '''
        with tf.GradientTape() as tape:
            r = encoder(x, training=True)
            z = projector(r, training=True)
            # print("z", z, "y", y)
            loss = loss_func(z, y)

        gradients = tape.gradient(
            loss, encoder.trainable_variables + projector.trainable_variables)
        optimizer1.apply_gradients(
            zip(gradients,
                encoder.trainable_variables + projector.trainable_variables))
        train_loss(loss)

    @tf.function
    def test_step_stage1(x, y):
        r = encoder(x, training=False)
        z = projector(r, training=False)
        t_loss = loss_func(z, y)
        test_loss(t_loss)

    print('Stage 1 training ...')
    for epoch in range(args.epoch):
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        test_loss.reset_states()

        for x, y in train_ds:
            train_step_stage1(x, y)

        for x_te, y_te in test_ds:
            test_step_stage1(x_te, y_te)

        template = 'Epoch {}, Loss: {}, Test Loss: {}'
        # print(template.format(epoch + 1,
        #                       train_loss.result(),
        #                       test_loss.result()))

    if args.draw_figures:
        # projecting data with the trained encoder, projector
        x_tr_proj = projector(encoder(x_train))
        x_te_proj = projector(encoder(x_test))
        # convert tensor to np.array
        x_tr_proj = x_tr_proj.numpy()
        x_te_proj = x_te_proj.numpy()
        print(x_tr_proj.shape, x_te_proj.shape)

        # check learned embedding using PCA
        pca = PCA(n_components=2)
        pca.fit(x_tr_proj)
        x_te_proj_pca = pca.transform(x_te_proj)

        x_te_proj_pca_df = pd.DataFrame(x_te_proj_pca, columns=['PC1', 'PC2'])
        x_te_proj_pca_df['label'] = y_test
        # PCA scatter plot
        fig, ax = plt.subplots()
        ax = sns.scatterplot('PC1',
                             'PC2',
                             data=x_te_proj_pca_df,
                             palette='tab10',
                             hue='label',
                             linewidth=0,
                             alpha=0.6,
                             ax=ax)

        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        title = 'Data: {}\nEmbedding: {}\nbatch size: {}; LR: {}'.format(
            args.data, LOSS_NAMES[args.loss], args.batch_size_1, args.lr_1)
        ax.set_title(title)
        fig.savefig('figs/PCA_plot_{}_{}_embed.png'.format(
            args.data, model_name))

        # density plot for PCA
        g = sns.jointplot('PC1', 'PC2', data=x_te_proj_pca_df, kind="hex")
        plt.subplots_adjust(top=0.95)
        g.fig.suptitle(title)

        g.savefig('figs/Joint_PCA_plot_{}_{}_embed.png'.format(
            args.data, model_name))

    # Stage 2: freeze the learned representations and then learn a classifier
    # on a linear layer using a softmax loss
    softmax = SoftmaxPred()

    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_ACC')
    test_loss = tf.keras.metrics.Mean(name='test_loss')
    test_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='test_ACC')

    cce_loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)

    # train step for the 2nd stage
    @tf.function
    def train_step(model, x, y):
        '''
        x: data tensor, shape: (batch_size, data_dim)
        y: data labels, shape: (batch_size, )
        '''
        with tf.GradientTape() as tape:
            r = model.layers[0](x, training=False)
            y_preds = model.layers[1](r, training=True)
            loss = cce_loss_obj(y, y_preds)

        # freeze the encoder, only train the softmax layer
        gradients = tape.gradient(loss, model.layers[1].trainable_variables)
        optimizer2.apply_gradients(
            zip(gradients, model.layers[1].trainable_variables))
        train_loss(loss)
        train_acc(y, y_preds)

    @tf.function
    def test_step(x, y):
        r = encoder(x, training=False)
        y_preds = softmax(r, training=False)
        t_loss = cce_loss_obj(y, y_preds)
        test_loss(t_loss)
        test_acc(y, y_preds)

    if args.write_summary:
        current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        train_log_dir = 'logs/{}/{}/{}/train'.format(model_name, args.data,
                                                     current_time)
        test_log_dir = 'logs/{}/{}/{}/test'.format(model_name, args.data,
                                                   current_time)
        train_summary_writer = tf.summary.create_file_writer(train_log_dir)
        test_summary_writer = tf.summary.create_file_writer(test_log_dir)

    print('Stage 2 training ...')
    model = tf.keras.Sequential([encoder, softmax])
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True)

    classifier = TensorFlowV2Classifier(
        model=model,
        loss_object=loss_object,
        train_step=train_step,
        nb_classes=10,
        input_shape=(28, 28, 1),
        clip_values=(0, 1),
    )

    # classifier.fit(x_train, y_train, batch_size=256, nb_epochs=20)

    for epoch in range(args.epoch):
        # Reset the metrics at the start of the next epoch
        train_loss.reset_states()
        train_acc.reset_states()
        test_loss.reset_states()
        test_acc.reset_states()

        for x, y in train_ds2:
            train_step(model, x, y)

        if args.write_summary:
            with train_summary_writer.as_default():
                tf.summary.scalar('loss', train_loss.result(), step=epoch)
                tf.summary.scalar('accuracy', train_acc.result(), step=epoch)

        for x_te, y_te in test_ds:
            test_step(x_te, y_te)

        if args.write_summary:
            with test_summary_writer.as_default():
                tf.summary.scalar('loss', test_loss.result(), step=epoch)
                tf.summary.scalar('accuracy', test_acc.result(), step=epoch)

        template = 'Epoch {}, Loss: {}, Acc: {}, Test Loss: {}, Test Acc: {}'
        print(
            template.format(epoch + 1, train_loss.result(),
                            train_acc.result() * 100, test_loss.result(),
                            test_acc.result() * 100))

    predictions = classifier.predict(x_test)
    print(predictions.shape, y_test.shape)
    accuracy = np.sum(np.argmax(predictions, axis=1) == y_test) / len(y_test)
    print("Accuracy on benign test examples: {}%".format(accuracy * 100))

    print('Stage 3 attacking ...')

    attack = ProjectedGradientDescent(estimator=classifier,
                                      eps=args.eps,
                                      eps_step=args.eps / 3,
                                      max_iter=20)
    x_test_adv = attack.generate(x=x_test)

    print('Stage 4 attacking ...')

    predictions = classifier.predict(x_test_adv)
    accuracy = np.sum(np.argmax(predictions, axis=1) == y_test) / len(y_test)
    print("Accuracy on adversarial test examples: {}%".format(accuracy * 100))

    natual(args.eps)
Example 7
def main(args):
  # Init logger
  if not os.path.isdir(args.save_path):
    os.makedirs(args.save_path)
  log = open(os.path.join(args.save_path, 'log_seed_{}.txt'.format(args.manualSeed)), 'w')
  print_log('save path : {}'.format(args.save_path), log)
  state = {k: v for k, v in args._get_kwargs()}
  print_log(state, log)
  print_log("Random Seed: {}".format(args.manualSeed), log)
  print_log("python version : {}".format(sys.version.replace('\n', ' ')), log)
  print_log("torch  version : {}".format(torch.__version__), log)
  print_log("cudnn  version : {}".format(torch.backends.cudnn.version()), log)

  # Init dataset
  if not os.path.isdir(args.data_path):
    os.makedirs(args.data_path)

  if args.dataset == 'cifar10':
    mean = [x / 255 for x in [125.3, 123.0, 113.9]]
    std = [x / 255 for x in [63.0, 62.1, 66.7]]
  elif args.dataset == 'cifar100':
    mean = [x / 255 for x in [129.3, 124.1, 112.4]]
    std = [x / 255 for x in [68.2, 65.4, 70.4]]
  elif args.dataset == 'imagenet32x32':
    mean = [x / 255 for x in [122.7, 116.7, 104.0]] 
    std = [x / 255 for x in [66.4, 64.6, 68.4]]
  elif args.dataset == 'svhn':
    pass
  else:
    assert False, "Unknow dataset : {}".format(args.dataset)

  if args.dataset == 'cifar10' or args.dataset == 'cifar100' or args.dataset == 'imagenet32x32':
    train_transform = transforms.Compose(
      [transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),
      transforms.Normalize(mean, std)])
    test_transform = transforms.Compose(
      [transforms.ToTensor(), transforms.Normalize(mean, std)])

  if args.dataset == 'cifar10':
    train_data = dset.CIFAR10(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR10(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 10
  elif args.dataset == 'cifar100':
    train_data = dset.CIFAR100(args.data_path, train=True, transform=train_transform, download=True)
    test_data = dset.CIFAR100(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 100
  elif args.dataset == 'svhn':
    def target_transform(target):
      return int(target[0])-1
    train_data = dset.SVHN(args.data_path, split='train', transform=transforms.Compose(
        [transforms.ToTensor(),]), download=True, target_transform=target_transform)
    extra_data = dset.SVHN(args.data_path, split='extra', transform=transforms.Compose(
        [transforms.ToTensor(),]), download=True, target_transform=target_transform)
    train_data.data = np.concatenate([train_data.data, extra_data.data])
    train_data.labels = np.concatenate([train_data.labels, extra_data.labels])
    print(train_data.data.shape, train_data.labels.shape)
    test_data = dset.SVHN(args.data_path, split='test', transform=transforms.Compose([transforms.ToTensor(),]), download=True, target_transform=target_transform)
    num_classes = 10
  elif args.dataset == 'imagenet32x32':
    train_data = IMAGENET32X32(args.data_path, train=True, transform=train_transform, download=True)
    test_data = IMAGENET32X32(args.data_path, train=False, transform=test_transform, download=True)
    num_classes = 1000
  else:
    assert False, 'Do not support dataset : {}'.format(args.dataset)

  train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=True,
                         num_workers=args.workers, pin_memory=True)
  test_loader = torch.utils.data.DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                        num_workers=args.workers, pin_memory=True)
  M_loader = torch.utils.data.DataLoader(train_data, batch_size=8, shuffle=True,
                         num_workers=args.workers, pin_memory=True)

  print_log("=> creating model '{}'".format(args.arch), log)
  # Init model, criterion, and optimizer
  net = models.__dict__[args.arch](num_classes=num_classes)

  #print_log("=> network:\n {}".format(net), log)

  net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

  # define loss function (criterion) and optimizer
  criterion = torch.nn.CrossEntropyLoss()

  """
  params_skip = []
  params_noskip = []
  skip_lists = ['bn', 'bias']
  for name, param in net.named_parameters():
    if any(name in skip_name for skip_name in skip_lists):
      params_skip.append(param)
    else:
      params_noskip.append(param)
  param_lrs = [{'params':params_skip, 'lr':state['learning_rate']},
		{'params':params_noskip, 'lr':state['learning_rate']}]
  param_lrs = []
  params = []
  names = []
  layers = [3,] + [54,]*3 + [2,]
  for i, (name, param) in enumerate(net.named_parameters()):
    params.append(param)
    names.append(name)
    if len(params) == layers[0]:
      param_dict = {'params': params, 'lr':state['learning_rate']}
      param_lrs.append(param_dict)
      params = []
      names = []
      layers.pop(0)
      
  """ 
  skip_lists = ['bn', 'bias']
  skip_idx = []
  for idx, (name, param) in enumerate(net.named_parameters()):
    if any(skip_name in name for skip_name in skip_lists):
      skip_idx.append(idx)

  param_lrs = net.parameters()
  
  if args.lars:
    optimizer = LARSOptimizer(param_lrs, state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False, steps=state['steps'], eta=state['eta'], skip_idx=skip_idx)
  else:
    optimizer = optim.SGD(param_lrs, state['learning_rate'], momentum=state['momentum'],
                weight_decay=state['decay'], nesterov=False)

  if args.use_cuda:
    net.cuda()
    criterion.cuda()

  recorder = RecorderMeter(args.epochs)
  # optionally resume from a checkpoint

  avg_norm = []
  if args.lw: 
    for param in net.parameters():
      avg_norm.append(0)

  # Main loop
  print_log('Epoch  Train_Prec@1  Train_Prec@5  Train_Loss  Test_Prec@1  Test_Prec@5  Test_Loss  Best_Prec@1  Time', log)
  for epoch in range(args.start_epoch, args.epochs):

    # train for one epoch
    start_time = time.time()
    train_top1, train_top5, train_loss = train(train_loader, M_loader, net, criterion, optimizer, epoch, log, args, avg_norm)
    training_time = time.time() - start_time

    # evaluate on validation set
    val_top1, val_top5, val_loss = validate(test_loader, net, criterion, log, args)
    recorder.update(epoch, train_loss, train_top1, val_loss, val_top1)

    print('{epoch:d}        {train_top1:.3f}      {train_top5:.3f}     {train_loss:.3f}      {test_top1:.3f}      {test_top5:.3f}    {test_loss:.3f}    {best_top1:.3f}      {time:.3f} '.format(epoch=epoch, time=training_time, train_top1=train_top1, train_top5=train_top5, train_loss=train_loss, test_top1=val_top1, test_top5=val_top5, test_loss=val_loss, best_top1=recorder.max_accuracy(False)))


  log.close()
Example 8
def main():
    best_prec1 = 0

    if not os.path.isdir(args.save_dir):
      os.makedirs(args.save_dir)
    log = open(os.path.join(args.save_dir, '{}.{}.log'.format(args.arch,args.prefix)), 'w')

    # create model
    print_log("=> creating model '{}'".format(args.arch), log)
    model = models.__dict__[args.arch](1000)
    print_log("=> Model : {}".format(model), log)
    print_log("=> parameter : {}".format(args), log)

    if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
      model.features = torch.nn.DataParallel(model.features)
      model.cuda()
    else:
        model = torch.nn.DataParallel(model, device_ids=[x for x in range(args.ngpu)]).cuda()

    # define loss function (criterion) and optimizer
    criterion = nn.CrossEntropyLoss().cuda()

    """
    param_lrs = []
    params = []
    names = []
    layers = [3,] + [12,]+[9,]*2+[12,]+[9,]*3+[12,]+[9,]*5+[12,]+[9,]*2+[2,]
    for i, (name, param) in enumerate(model.named_parameters()):
        params.append(param)
        names.append(name)
        if len(params) == layers[0]:
            param_dict = {'params': params, 'lr':args.learning_rate}
            param_lrs.append(param_dict)
            print(names)
            params = []
            names = []
            layers.pop(0)
    """
    skip_lists = ['bn', 'bias']
    skip_idx = []
    for idx, (name, param) in enumerate(model.named_parameters()):
        if any(skip_name in name for skip_name in skip_lists):
            skip_idx.append(idx)

    param_lrs = model.parameters()
    if args.lars:  
        optimizer = LARSOptimizer(param_lrs, args.learning_rate, momentum=args.momentum,
                weight_decay=args.weight_decay, nesterov=False, steps=args.steps, eta=args.eta, skip_idx=skip_idx)
    else:
        optimizer = optim.SGD(param_lrs, args.learning_rate, momentum=args.momentum,
                weight_decay=args.weight_decay, nesterov=False)

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print_log("=> loading checkpoint '{}'".format(args.resume), log)
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            best_prec1 = checkpoint['best_prec1']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print_log("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), log)
        else:
            print_log("=> no checkpoint found at '{}'".format(args.resume), log)

    cudnn.benchmark = True

    # Data loading code
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    train_dataset = datasets.ImageFolder(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)

    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    M_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=8, shuffle=True,
        num_workers=args.workers, pin_memory=True, sampler=None)

    if args.evaluate:
        validate(val_loader, model, criterion)
        return

    filename = os.path.join(args.save_dir, 'checkpoint.{}.{}.pth.tar'.format(args.arch, args.prefix))
    bestname = os.path.join(args.save_dir, 'best.{}.{}.pth.tar'.format(args.arch, args.prefix))

    avg_norm = []
    if args.lw: 
        for param in model.parameters():
            avg_norm.append(0)

    print_log('Epoch  Train_Prec@1  Train_Prec@5  Train_Loss  Test_Prec@1  Test_Prec@5  Test_Loss  Best_Prec@1  Time', log)
    for epoch in range(args.start_epoch, args.epochs):

        # train for one epoch
        start_time = time.time()
        train_top1, train_top5, train_loss = train(train_loader, M_loader, model, criterion, optimizer, epoch, log, avg_norm)
        training_time = time.time() - start_time

        # evaluate on validation set
        val_top1, val_top5, val_loss = validate(val_loader, model, criterion, log)

        # remember best prec@1 and save checkpoint
        is_best = val_top1 > best_prec1
        best_prec1 = max(val_top1, best_prec1)

        print('{epoch:d}        {train_top1:.3f}      {train_top5:.3f}     {train_loss:.3f}      {test_top1:.3f}      {test_top5:.3f}    {test_loss:.3f}    {best_top1:.3f}      {time:.3f} '.format(epoch=epoch, time=training_time, train_top1=train_top1, train_top5=train_top5, train_loss=train_loss, test_top1=val_top1, test_top5=val_top5, test_loss=val_loss, best_top1=best_prec1))
		
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer' : optimizer.state_dict(),
        }, is_best, filename, bestname)
        # measure elapsed time

    log.close()