Example #1

import argparse
import logging
import os
from datetime import datetime as dt

import deepchem as dc
import numpy as np
import tensorflow as tf
from deepchem.models import ChemCeption
from deepchem.models.optimizers import RMSProp

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
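
# The script assumes four project-level helpers that are defined elsewhere in
# the original module. The sketch below is illustrative only: the tox21
# entries and the loss computation are assumptions, not the original code.
loaders = {"tox21": dc.molnet.load_tox21}
metric_types = {"tox21": dc.metrics.roc_auc_score}

def get_task_mode(dataset):
    # tox21 is a multitask classification benchmark
    return "classification" if dataset == "tox21" else "regression"

def compute_loss_on_valid(valid, model, tasks, mode="classification", verbose=True):
    # One plausible implementation: weighted mean loss of the model's
    # predictions on the validation set. `tasks` is kept only for signature
    # compatibility with the call sites below.
    y_pred = model.predict(valid)
    if mode == "classification":
        # predict() returns class probabilities of shape
        # (n_samples, n_tasks, n_classes); use weighted binary cross-entropy.
        probs = np.clip(y_pred, 1e-7, 1.0)
        losses = -(valid.y * np.log(probs[:, :, 1]) + (1.0 - valid.y) * np.log(probs[:, :, 0]))
    else:
        losses = (valid.y - y_pred.reshape(valid.y.shape)) ** 2
    loss = float(np.sum(losses * valid.w) / np.sum(valid.w))
    if verbose:
        logger.info("Validation loss: {}".format(loss))
    return loss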
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_spec", default="std", help="Image specification to load")
    parser.add_argument("--early_stopping_epoch", default=10, type=int, help="Number of epochs to test early stopping after")
    parser.add_argument("--use_augment", action='store_true', help="Whether to perform real-time augmentation")
    parser.add_argument("--base_filters", default=16, type=int, help="Number of base filters to use")
    parser.add_argument("--inception_per_block", default=3, type=int, help="Number of inception layers per block")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size used for training")
    parser.add_argument("--learning_rate", default=1e-3, type=float, help="Learning rate used for fine-tuning")
    parser.add_argument("--restore_exp", help="Experiment to restore from")
    parser.add_argument("--train_all", action="store_true", help="Whether to only train all variables")
    parser.add_argument("--dataset", default="tox21", help="Dataset to train on")

    args = parser.parse_args()

    layers_per_block = args.inception_per_block
    inception_blocks = {"A": layers_per_block, "B": layers_per_block, "C": layers_per_block}

    mode = get_task_mode(args.dataset)

    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
    DIRNAME = os.path.join(os.environ.get("SCRATCH", ROOT_DIR), "deepchem-data")
    hparams_dir = "filters_{}_blocklayers_{}_imgspec_{}".format(args.base_filters, layers_per_block, args.img_spec)
    restore_dir = os.path.join(DIRNAME, "chemception", hparams_dir, "best-models", args.restore_exp + "/")
    logger.info("Restore dir: {}".format(restore_dir))

    load_fn = loaders[args.dataset]
    tasks, dataset, transformers = load_fn(featurizer="smiles2img", data_dir=DIRNAME, save_dir=DIRNAME, img_spec=args.img_spec, split="stratified")

    metric_type = metric_types[args.dataset]

    # np.mean averages the per-task scores; a single-task dataset needs no averaging
    task_averager = np.mean
    if len(tasks) == 1:
        task_averager = None

    metric = dc.metrics.Metric(metric_type, task_averager=task_averager, mode=mode, verbose=False)
    train, valid, test = dataset

    # Setup directory for experiment
    exp_name = dt.now().strftime("%d-%m-%Y--%H-%M-%S")
    model_dir_1 = os.path.join(DIRNAME, args.dataset + "-finetune", "chemception", hparams_dir, exp_name)

    # Optimizer and logging
    optimizer = RMSProp(learning_rate=args.learning_rate)

    logger.info("Args used: {}".format(args))
    logger.info("Num_tasks: {}".format(len(tasks)))

    tensorboard = True
    if tf.executing_eagerly():
        tensorboard = False

    # The dummy model mirrors the configuration the source network was
    # pretrained with (100 regression tasks); it is used only to restore the
    # checkpoint and copy variables into finetune_model
    dummy_model = ChemCeption(n_tasks=100, img_spec=args.img_spec,
                        inception_blocks=inception_blocks,
                        base_filters=args.base_filters, augment=args.use_augment,
                        model_dir=None, mode="regression",
                        n_classes=2, batch_size=args.batch_size,
                        optimizer=None, tensorboard=tensorboard,
                        tensorboard_log_frequency=100)

    finetune_model = ChemCeption(n_tasks=len(tasks), img_spec=args.img_spec,
                        inception_blocks=inception_blocks,
                        base_filters=args.base_filters, augment=args.use_augment,
                        model_dir=model_dir_1, mode=mode,
                        n_classes=2, batch_size=args.batch_size,
                        optimizer=optimizer, tensorboard=tensorboard,
                        tensorboard_log_frequency=100)

    finetune_model.load_from_pretrained(source_model=dummy_model, assignment_map=None, model_dir=restore_dir, include_top=False)
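    # include_top=False copies every layer except the final dense head, so the
    # new len(tasks)-way head is trained from scratch while the convolutional
    # trunk keeps its pretrained weights.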

    logger.info("Created model dir at {}".format(model_dir_1))
    best_models_dir_1 = os.path.join(DIRNAME, args.dataset + "-finetune", "chemception", hparams_dir, "best-models", exp_name)

    loss_old = compute_loss_on_valid(valid, finetune_model, tasks, mode=mode)
    logger.info("Saving best model so far")
    finetune_model.save_checkpoint(model_dir=best_models_dir_1)

    train_scores = finetune_model.evaluate(train, [metric], [])
    valid_scores = finetune_model.evaluate(valid, [metric], [])
    test_scores = finetune_model.evaluate(test, [metric], [])

    logger.info("Train-{}: {}".format(metric.name, train_scores[metric.name]))
    logger.info("Valid-{}: {}".format(metric.name, valid_scores[metric.name]))
    logger.info("Test-{}: {}".format(metric.name, test_scores[metric.name]))

    for rep_num in range(10):
        logger.info("Training model for {} epochs.".format(args.early_stopping_epoch))
        if args.train_all:
            finetune_model.fit(train, nb_epoch=args.early_stopping_epoch, checkpoint_interval=0)
        else:
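            # The last two trainable variables are the kernel and bias of the
            # final dense layer, so only the task head is fine-tuned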
            var_list = finetune_model.model.trainable_variables[-2:]
            finetune_model.fit(train, nb_epoch=args.early_stopping_epoch, variables=var_list, checkpoint_interval=0)
        loss_new = compute_loss_on_valid(valid, finetune_model, tasks, mode=mode, verbose=False)

        train_scores = finetune_model.evaluate(train, [metric], [])
        valid_scores = finetune_model.evaluate(valid, [metric], [])
        test_scores = finetune_model.evaluate(test, [metric], [])

        logger.info("Train-{}: {}".format(metric.name, train_scores[metric.name]))
        logger.info("Valid-{}: {}".format(metric.name, valid_scores[metric.name]))
        logger.info("Test-{}: {}".format(metric.name, test_scores[metric.name]))

        logger.info("Computed loss on validation set after {} epochs: {}".format(args.early_stopping_epoch, loss_new))
        if loss_new > loss_old:
            logger.info("No improvement in validation loss. Enforcing early stopping.")
            break

        logger.info("Saving best model so far")
        finetune_model.save_checkpoint(model_dir=best_models_dir_1)
        loss_old = loss_new
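
# Example invocation for this fine-tuning script (the script name and the
# restore_exp timestamp are hypothetical):
#   python finetune_chemception.py --dataset tox21 --restore_exp 01-01-2020--12-00-00 --train_all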
Example #2

# This example reuses the imports, logger setup, and helper definitions
# (loaders, metric_types, get_task_mode, compute_loss_on_valid) from
# Example #1, plus the exponential learning-rate schedule:
from deepchem.models.optimizers import ExponentialDecay
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument("--img_spec", default="std", help="Image specification to load")
    parser.add_argument("--early_stopping_epoch", default=10, type=int, help="Number of epochs to check early stopping after for")
    parser.add_argument("--use_augment", action='store_true', help="Whether to perform real-time augmentation")
    parser.add_argument("--base_filters", default=16, type=int, help="Number of base filters to use")
    parser.add_argument("--inception_per_block", default=3, type=int, help="Number of inception layers per block")
    parser.add_argument("--batch_size", default=32, type=int, help="Batch size used for training")
    parser.add_argument("--learning_rate", default=1e-4, type=float, help="Learning rate used.")
    parser.add_argument("--dataset", default="tox21", help="Dataset to train on.")

    args = parser.parse_args()

    layers_per_block = args.inception_per_block
    inception_blocks = {"A": layers_per_block, "B": layers_per_block, "C": layers_per_block}

    mode = get_task_mode(args.dataset)

    DIRNAME = os.path.join(os.environ.get("SCRATCH", "./"), "deepchem-data")
    load_fn = loaders[args.dataset]
    tasks, dataset, transformers = load_fn(featurizer="smiles2img", data_dir=DIRNAME, save_dir=DIRNAME, img_spec=args.img_spec, split="stratified")

    metric_type = metric_types[args.dataset]

    task_averager = np.mean
    if len(tasks) == 1:
        task_averager = None

    metric = dc.metrics.Metric(metric_type, task_averager=task_averager, mode=mode, verbose=False)
    train, valid, test = dataset

    # Setup directory for experiment
    exp_name = dt.now().strftime("%d-%m-%Y--%H-%M-%S")
    hparams_dir = "filters_{}_blocklayers_{}_imgspec_{}".format(args.base_filters, layers_per_block, args.img_spec)
    model_dir_1 = os.path.join(DIRNAME, args.dataset, "chemception", hparams_dir, exp_name)

    # Optimizer and logging
    optimizer = RMSProp(learning_rate=args.learning_rate)

    logger.info("Dataset used: {}".format(args.dataset))
    logger.info("Args used: {}".format(args))
    logger.info("Num_tasks: {}".format(len(tasks)))

    ###### TRAINING FIRST PART WITH CONSTANT LEARNING RATE ###############
    model = ChemCeption(n_tasks=len(tasks), img_spec=args.img_spec,
                        inception_blocks=inception_blocks,
                        base_filters=args.base_filters, augment=args.use_augment,
                        model_dir=model_dir_1, mode=mode,
                        n_classes=2, batch_size=args.batch_size,
                        optimizer=optimizer, tensorboard=True,
                        tensorboard_log_frequency=100)
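    # Build the underlying Keras model now so that the untrained weights can be
    # checkpointed below, before any call to fit().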
    model._ensure_built()

    logger.info("Created model dir at {}".format(model_dir_1))
    best_models_dir_1 = os.path.join(DIRNAME, args.dataset, "chemception", hparams_dir, "best-models", exp_name)
    logger.info("Saving best model so far")
    model.save_checkpoint(model_dir=best_models_dir_1)

    loss_old = compute_loss_on_valid(valid, model, tasks, mode=mode)

    train_scores = model.evaluate(train, [metric], [])
    valid_scores = model.evaluate(valid, [metric], [])
    test_scores = model.evaluate(test, [metric], [])

    logger.info("Train-{}: {}".format(metric.name, train_scores[metric.name]))
    logger.info("Valid-{}: {}".format(metric.name, valid_scores[metric.name]))
    logger.info("Test-{}: {}".format(metric.name, test_scores[metric.name]))

    for rep_num in range(2):
        logger.info("Training model for {} epochs.".format(args.early_stopping_epoch))
        model.fit(train, nb_epoch=args.early_stopping_epoch, checkpoint_interval=0)
        loss_new = compute_loss_on_valid(valid, model, tasks, mode=mode, verbose=False)

        train_scores = model.evaluate(train, [metric], [])
        valid_scores = model.evaluate(valid, [metric], [])
        test_scores = model.evaluate(test, [metric], [])

        logger.info("Train-{}: {}".format(metric.name, train_scores[metric.name]))
        logger.info("Valid-{}: {}".format(metric.name, valid_scores[metric.name]))
        logger.info("Test-{}: {}".format(metric.name, test_scores[metric.name]))

        logger.info("Computed loss on validation set after {} epochs: {}".format(args.early_stopping_epoch, loss_new))
        if loss_new > loss_old:
            logger.info("No improvement in validation loss. Enforcing early stopping.")
            break

        logger.info("Saving best model so far")
        model.save_checkpoint(model_dir=best_models_dir_1)
        loss_old = loss_new

    ###### TRAINING SECOND PART WITH DECAYING LEARNING RATE ###############

    # Optimizer and logging
    decay_steps = args.early_stopping_epoch * train.y.shape[0] // args.batch_size
    logger.info("Decay steps: {}".format(decay_steps))

    lr = ExponentialDecay(initial_rate=args.learning_rate, decay_rate=0.92, decay_steps=decay_steps, staircase=True)
    optimizer = RMSProp(learning_rate=lr)

    # Setup directory for the second-stage experiment (hparams_dir is unchanged)
    exp_name = dt.now().strftime("%d-%m-%Y--%H-%M-%S")

    new_model = ChemCeption(n_tasks=len(tasks), img_spec=args.img_spec,
                            inception_blocks=inception_blocks,
                            base_filters=args.base_filters, augment=args.use_augment,
                            model_dir=model_dir_1, mode=mode,
                            n_classes=2, batch_size=args.batch_size,
                            optimizer=optimizer, tensorboard=True,
                            tensorboard_log_frequency=100)
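    # Warm-start the second stage from the best stage-one checkpoint; only the
    # optimizer (now with a decaying learning rate) differs from the first model.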
    new_model.restore(model_dir=best_models_dir_1)

    best_models_dir_2 = os.path.join(DIRNAME, args.dataset, "chemception", hparams_dir, "best-models", exp_name)
    logger.info("Created best model dir for second stage at {}".format(best_models_dir_2))

    loss_old = compute_loss_on_valid(valid, new_model, tasks, mode=mode)

    for rep_num in range(2):
        logger.info("Training model for {} epochs.".format(args.early_stopping_epoch))
        new_model.fit(train, nb_epoch=args.early_stopping_epoch, checkpoint_interval=0)
        loss_new = compute_loss_on_valid(valid, new_model, tasks, mode=mode, verbose=False)

        train_scores = new_model.evaluate(train, [metric], [])
        valid_scores = new_model.evaluate(valid, [metric], [])
        test_scores = new_model.evaluate(test, [metric], [])

        logger.info("Train-{}: {}".format(metric.name, train_scores[metric.name]))
        logger.info("Valid-{}: {}".format(metric.name, valid_scores[metric.name]))
        logger.info("Test-{}: {}".format(metric.name, test_scores[metric.name]))

        logger.info("Computed loss on validation set after {} epochs: {}".format(args.early_stopping_epoch, loss_new))
        if loss_new > loss_old:
            logger.info("No improvement in validation loss. Enforcing early stopping.")
            break

        logger.info("Saving best model so far")
        new_model.save_checkpoint(model_dir=best_models_dir_2)
        loss_old = loss_new
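
# Example invocation for this two-stage training script (script name hypothetical):
#   python train_chemception.py --dataset tox21 --learning_rate 1e-4 --use_augment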