Example #1
import tensorflow as tf  # assumed import; WideResNet, load_CIFAR_10, train, validate and the hyperparameters are defined elsewhere in the project

def main():
    model = WideResNet(numClasses, depth=28, width=2)
    emaModel = WideResNet(numClasses, depth=28, width=2)

    (X_train, Y_train), U_train, (X_test, Y_test) = load_CIFAR_10(labeledExamples=labeledExamples)
    model.build(input_shape=(None, 32, 32, 3))
    emaModel.build(input_shape=(None, 32, 32, 3))

    X_train = tf.data.Dataset.from_tensor_slices({'image': X_train, 'label': Y_train})
    X_test = tf.data.Dataset.from_tensor_slices({'image': X_test, 'label': Y_test})  
    U_train = tf.data.Dataset.from_tensor_slices(U_train)

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)  # the 'lr' kwarg is deprecated; use 'learning_rate'
    emaModel.set_weights(model.get_weights())  # start the EMA model as an exact copy of the training model

    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    T = tf.constant(0.5)  # sharpening temperature, kept as a Tensor to avoid tf.function retracing
    beta = tf.Variable(0., shape=())  # mixup coefficient, presumably reassigned inside train()

    for epoch in range(epochs):
        train(X_train, U_train, model, emaModel, optimizer, epoch, T, beta)
        testAccuracy = validate(X_test, emaModel).result()
        print("Epoch: {} and test accuracy: {}".format(epoch, testAccuracy))
        
        with open('results.txt', 'a') as f:  # append, so earlier epochs are not overwritten; the with-block closes the file
            f.write("num_label={}, accuracy={}, epoch={}\n".format(labeledExamples, testAccuracy, epoch))
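
The EMA model above is initialised as a copy of the training model, but the shadowing update itself happens inside train(), which is not shown. A minimal sketch of the usual exponential-moving-average step, assuming a decay hyperparameter ema_decay (not part of this snippet):

def ema_update(model, ema_model, ema_decay=0.999):
    # w_ema <- decay * w_ema + (1 - decay) * w, applied weight-by-weight
    new_weights = [ema_decay * ema_w + (1.0 - ema_decay) * w
                   for w, ema_w in zip(model.get_weights(), ema_model.get_weights())]
    ema_model.set_weights(new_weights)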
Example #2
def main():
    args = vars(get_args())
    dir_path = os.path.dirname(os.path.realpath(__file__))
    if args['config_path'] is not None and os.path.exists(os.path.join(dir_path, args['config_path'])):
        args = load_config(args)
    start_epoch = 0
    log_path = f'.logs/{args["dataset"]}@{args["labelled_examples"]}'
    ckpt_dir = f'{log_path}/checkpoints'

    datasetX, datasetU, val_dataset, test_dataset, num_classes = fetch_dataset(args, log_path)

    model = WideResNet(num_classes, depth=28, width=2)
    model.build(input_shape=(None, 32, 32, 3))
    optimizer = tf.keras.optimizers.Adam(learning_rate=args['learning_rate'])
    model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)

    ema_model = WideResNet(num_classes, depth=28, width=2)
    ema_model.build(input_shape=(None, 32, 32, 3))
    ema_model.set_weights(model.get_weights())
    ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
    ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

    if args['resume']:
        model_ckpt.restore(manager.latest_checkpoint)
        ema_ckpt.restore(ema_manager.latest_checkpoint)
        model_ckpt.step.assign_add(1)
        ema_ckpt.step.assign_add(1)
        start_epoch = int(model_ckpt.step)
        print(f'Restored @ epoch {start_epoch} from {manager.latest_checkpoint} and {ema_manager.latest_checkpoint}')

    train_writer = val_writer = test_writer = None
    if args['tensorboard']:
        train_writer = tf.summary.create_file_writer(f'{log_path}/train')
        val_writer = tf.summary.create_file_writer(f'{log_path}/validation')
        test_writer = tf.summary.create_file_writer(f'{log_path}/test')

    # wrap args that are read inside tf.function-decorated code in tf.constant/tf.Variable,
    # so that changing values do not trigger retracing (retraced graphs accumulate and leak memory)
    args['T'] = tf.constant(args['T'])
    args['beta'] = tf.Variable(0., shape=())
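    # Sketch of the retracing problem this avoids (illustrative, not from this repo):
    #   @tf.function
    #   def f(x, T):
    #       return tf.pow(x, 1.0 / T)
    #   f(x, 0.5)               # plain Python float: a new graph trace per distinct value
    #   f(x, tf.constant(0.5))  # Tensor argument: traced once and reused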
    for epoch in range(start_epoch, args['epochs']):
        xe_loss, l2u_loss, total_loss, accuracy = train(datasetX, datasetU, model, ema_model, optimizer, epoch, args)
        val_xe_loss, val_accuracy = validate(val_dataset, ema_model, epoch, args, split='Validation')
        test_xe_loss, test_accuracy = validate(test_dataset, ema_model, epoch, args, split='Test')

        if (epoch - start_epoch) % 16 == 0:
            model_save_path = manager.save(checkpoint_number=int(model_ckpt.step))
            ema_save_path = ema_manager.save(checkpoint_number=int(ema_ckpt.step))
            print(f'Saved model checkpoint for epoch {int(model_ckpt.step)} @ {model_save_path}')
            print(f'Saved ema checkpoint for epoch {int(ema_ckpt.step)} @ {ema_save_path}')

        model_ckpt.step.assign_add(1)
        ema_ckpt.step.assign_add(1)

        step = args['val_iteration'] * (epoch + 1)
        if args['tensorboard']:
            with train_writer.as_default():
                tf.summary.scalar('xe_loss', xe_loss.result(), step=step)
                tf.summary.scalar('l2u_loss', l2u_loss.result(), step=step)
                tf.summary.scalar('total_loss', total_loss.result(), step=step)
                tf.summary.scalar('accuracy', accuracy.result(), step=step)
            with val_writer.as_default():
                tf.summary.scalar('xe_loss', val_xe_loss.result(), step=step)
                tf.summary.scalar('accuracy', val_accuracy.result(), step=step)
            with test_writer.as_default():
                tf.summary.scalar('xe_loss', test_xe_loss.result(), step=step)
                tf.summary.scalar('accuracy', test_accuracy.result(), step=step)

    if args['tensorboard']:
        for writer in [train_writer, val_writer, test_writer]:
            writer.flush()
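
args['T'] above is the sharpening temperature of MixMatch-style pseudo-labelling; the call site lives inside train(), which is not shown. A sketch of the standard sharpening step, assuming p is a batch of averaged label guesses (names are illustrative):

def sharpen(p, T):
    # Raise each class probability to the power 1/T and renormalise;
    # T < 1 pushes the distribution toward one-hot.
    p_t = tf.pow(p, 1.0 / T)
    return p_t / tf.reduce_sum(p_t, axis=1, keepdims=True)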
Example #3
def main():
    global datasetX, datasetU, val_dataset, model, ema_model, optimizer, epoch, args
    args = vars(get_args())
    epoch = args['epochs']  # total epoch budget, handed to Bayesian_Optimization below
    start_epoch = 0
    record_path = f'.logs/{args["dataset"]}@{args["labelled_examples"]}'
    ckpt_dir = f'{record_path}/checkpoints'
    datasetX, datasetU, val_dataset, test_dataset, num_classes = preprocess_dataset(args, record_path)

    model = WideResNet(num_classes, depth=28, width=2)
    model.build(input_shape=(None, 32, 32, 3))
    optimizer = tf.keras.optimizers.Adam(learning_rate=args['learning_rate'])
    model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
    manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)

    ema_model = WideResNet(num_classes, depth=28, width=2)
    ema_model.build(input_shape=(None, 32, 32, 3))
    ema_model.set_weights(model.get_weights())
    ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
    ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

    if args['resume']:
        model_ckpt.restore(manager.latest_checkpoint)
        ema_ckpt.restore(ema_manager.latest_checkpoint)
        model_ckpt.step.assign_add(1)
        ema_ckpt.step.assign_add(1)
        start_epoch = int(model_ckpt.step)
        print(f'Restored @ epoch {start_epoch} from {manager.latest_checkpoint} and {ema_manager.latest_checkpoint}')

    train_writer = val_writer = test_writer = None
    if args['tensorboard']:
        train_writer = tf.summary.create_file_writer(f'{record_path}/train')
        val_writer = tf.summary.create_file_writer(f'{record_path}/validation')
        test_writer = tf.summary.create_file_writer(f'{record_path}/test')

    args['T'] = tf.constant(args['T'])
    args['beta'] = tf.Variable(0., shape=())

    if args['mode'] == 'tuning':
        params = [datasetX, datasetU, val_dataset, model, ema_model, optimizer, epoch, args]
        Bayesian_Optimization(params)
    else:
        for epoch in range(start_epoch, args['epochs']):
            xe_loss, l2u_loss, total_loss, accuracy = train(datasetX, datasetU, model, ema_model, optimizer, epoch,
                                                            args)
            val_xe_loss, val_accuracy = validate(val_dataset, ema_model, epoch, args, split='Validation')
            test_xe_loss, test_accuracy = validate(test_dataset, ema_model, epoch, args, split='Test')

            if (epoch - start_epoch) % 16 == 0:
                model_save_path = manager.save(checkpoint_number=int(model_ckpt.step))
                ema_save_path = ema_manager.save(checkpoint_number=int(ema_ckpt.step))
                print(f'Saved model checkpoint for epoch {int(model_ckpt.step)} @ {model_save_path}')
                print(f'Saved ema checkpoint for epoch {int(ema_ckpt.step)} @ {ema_save_path}')

            model_ckpt.step.assign_add(1)
            ema_ckpt.step.assign_add(1)

            step = args['val_iteration'] * (epoch + 1)
            if args['tensorboard']:
                with train_writer.as_default():
                    tf.summary.scalar('xe_loss', xe_loss.result(), step=step)
                    tf.summary.scalar('l2u_loss', l2u_loss.result(), step=step)
                    tf.summary.scalar('total_loss', total_loss.result(), step=step)
                    tf.summary.scalar('accuracy', accuracy.result(), step=step)
                with val_writer.as_default():
                    tf.summary.scalar('xe_loss', val_xe_loss.result(), step=step)
                    tf.summary.scalar('accuracy', val_accuracy.result(), step=step)
                with test_writer.as_default():
                    tf.summary.scalar('xe_loss', test_xe_loss.result(), step=step)
                    tf.summary.scalar('accuracy', test_accuracy.result(), step=step)

    if args['tensorboard']:
        for writer in [train_writer, val_writer, test_writer]:
            writer.flush()
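
args['beta'] is declared as a tf.Variable so that train() can refill it with a fresh mixup coefficient on every step without retracing; the sampling code is not part of this snippet. A sketch of the usual MixMatch draw, assuming an alpha hyperparameter (illustrative):

import numpy as np

lam = np.random.beta(alpha, alpha)
lam = max(lam, 1.0 - lam)  # MixMatch keeps the larger mixing weight
args['beta'].assign(lam)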
Example #4
cifar10_test_dataset = tf.data.Dataset.from_tensor_slices({
    'image': x_test,
    'label': y_test
})

trainX, trainU, validation = split_dataset(cifar10_train_dataset, 4000, 5000, 10)  # cifar10_train_dataset is built earlier in the notebook (not shown)
#%%
model = WideResNet(10, depth=28, width=2)
model.build(input_shape=(None, 32, 32, 3))
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)
# model_ckpt = tf.train.Checkpoint(step=tf.Variable(0), optimizer=optimizer, net=model)
# manager = tf.train.CheckpointManager(model_ckpt, f'{ckpt_dir}/model', max_to_keep=3)

ema_model = WideResNet(10, depth=28, width=2)
ema_model.build(input_shape=(None, 32, 32, 3))
ema_model.set_weights(model.get_weights())
# ema_ckpt = tf.train.Checkpoint(step=tf.Variable(0), net=ema_model)
# ema_manager = tf.train.CheckpointManager(ema_ckpt, f'{ckpt_dir}/ema', max_to_keep=3)

#%%
def train(trainX, trainU, model, ema_model, optimizer, epoch):
    xe_loss_avg = tf.keras.metrics.Mean()
    l2u_loss_avg = tf.keras.metrics.Mean()
    total_loss_avg = tf.keras.metrics.Mean()
    accuracy = tf.keras.metrics.SparseCategoricalAccuracy()

    def shuffle_and_batch(dataset):
        return dataset.shuffle(buffer_size=int(1e6)).batch(batch_size=64, drop_remainder=True)

    iteratorX = iter(shuffle_and_batch(trainX))
    iteratorU = iter(shuffle_and_batch(trainU))