Example 1
0
    # --- VALIDATION DATA PIPELINE ---
    # Batch the validation filenames, decode each batch with load_batch_tf,
    # and prefetch one batch ahead to overlap loading with computation.
    valid_data = tf.data.Dataset.from_tensor_slices(valid_files)
    valid_data = valid_data.batch(valid_batch_size)
    valid_data = valid_data.map(load_batch_tf)
    valid_data = valid_data.prefetch(1)

    # --- PRETRAIN DATA PIPELINE ---
    # Every 10th training file forms the (smaller) pretraining subset.
    pretrain_files = [f for i, f in enumerate(train_files) if i % 10 == 0]
    pretrain_data = tf.data.Dataset.from_tensor_slices(pretrain_files)
    pretrain_data = pretrain_data.batch(pretrain_batch_size)
    pretrain_data = pretrain_data.map(load_batch_tf)
    pretrain_data = pretrain_data.prefetch(1)

    ### MODEL LOADING ###
    # Temporarily prepend models/<model> to sys.path so `import model`
    # resolves to that directory's model.py; reload in case a different
    # `model` module was imported earlier in this process.
    sys.path.insert(0, os.path.abspath(f'models/{args.model}'))
    import model
    importlib.reload(model)
    model = model.GCNPolicy()  # rebind: the instance shadows the module
    del sys.path[0]

    ### TRAINING LOOP ###
    # A callable learning rate lets `lr` be reassigned later and picked up
    # by the optimizer on subsequent steps (TF1 eager idiom).
    optimizer = tf.train.AdamOptimizer(
        learning_rate=lambda: lr)  # dynamic LR trick
    best_loss = np.inf
    for epoch in range(max_epochs + 1):
        log(f"EPOCH {epoch}...", logfile)
        epoch_loss_avg = tfe.metrics.Mean()
        epoch_accuracy = tfe.metrics.Accuracy()

        # TRAIN
        # Epoch 0 is reserved for pretraining; pretrain() returns the
        # number of layers it processed.
        if epoch == 0:
            n = pretrain(model=model, dataloader=pretrain_data)
            log(f"PRETRAINED {n} LAYERS", logfile)
Example 2
0
    ### MODEL LOADING ###
    # Temporarily prepend models/<model> to sys.path so `import model`
    # resolves to that directory's model.py; reload to evict any previously
    # imported `model` module.
    sys.path.insert(0, os.path.abspath(f'models/{args.model}'))
    import model
    importlib.reload(model)
    distilled_model = model.Policy()
    del sys.path[0]
    distilled_model.to(device)

    ### TEACHER MODEL LOADING ###
    # A teacher is only loaded for the distillation / no-e2e configurations
    # (presumably "no end-to-end" — confirm against arg parser).
    teacher = None
    if (args.distilled or args.no_e2e):
        sys.path.insert(0, os.path.abspath(f'models/{teacher_model}'))
        import model
        importlib.reload(model)
        teacher = model.GCNPolicy()
        del sys.path[0]
        teacher.restore_state(
            f"trained_models/{args.problem}/{teacher_model}/{args.seed}/best_params.pkl"
        )
        teacher.to(device)
        teacher.eval()  # teacher stays frozen: inference mode only

    # The student is what gets trained below.
    model = distilled_model

    ### TRAINING LOOP ###
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    # Shrink LR 5x (factor=0.2) when the monitored min-mode metric plateaus
    # for `patience` epochs.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='min',
                                                           factor=0.2,
                                                           patience=patience,
Example 3
0
            for seed in seeds:
                # Derive the TF graph seed from the numpy RNG so both
                # frameworks are seeded reproducibly from one value.
                rng = np.random.RandomState(seed)
                tf.set_random_seed(rng.randint(np.iinfo(int).max))

                # Per-(policy, seed) descriptor dict used by the evaluation loop.
                policy = {}
                policy['name'] = policy_name
                policy['type'] = policy_type

                if policy['type'] == 'gcnn':
                    # load model
                    # Temporarily prepend models/<name> so `import model`
                    # resolves to that directory's model.py.
                    sys.path.insert(
                        0, os.path.abspath(f"models/{policy['name']}"))
                    import model
                    importlib.reload(model)
                    del sys.path[0]
                    policy['model'] = model.GCNPolicy()
                    policy['model'].restore_state(
                        f"trained_models/{args.problem}/{policy['name']}/{seed}/best_params.pkl"
                    )
                    # Compile the forward pass to a graph function for
                    # faster repeated calls (TF1 eager `defun`).
                    policy['model'].call = tfe.defun(
                        policy['model'].call,
                        input_signature=policy['model'].input_signature)
                    # Dtypes of the ten tensors that make up one GCNN batch.
                    policy['batch_datatypes'] = [
                        tf.float32, tf.int32, tf.float32, tf.float32, tf.int32,
                        tf.int32, tf.int32, tf.int32, tf.int32, tf.float32
                    ]
                    policy['batch_fun'] = load_batch_gcnn
                else:
                    # load feature normalization parameters
                    try:
                        with open(
Example 4
0
    else:
        # Restrict TensorFlow to the requested GPU.
        os.environ['CUDA_VISIBLE_DEVICES'] = f'{args.gpu}'
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # allocate GPU memory on demand
    tf.enable_eager_execution(config)
    tf.executing_eagerly()  # NOTE(review): returns a bool; value is discarded

    # load and assign tensorflow models to policies (share models and update parameters)
    # Cache one GCNPolicy per model name so policies sharing a name share
    # the same underlying model object.
    loaded_models = {}
    for policy in branching_policies:
        if policy['type'] == 'gcnn':
            if policy['name'] not in loaded_models:
                sys.path.insert(0, os.path.abspath(f"models/{policy['name']}"))
                import model
                importlib.reload(model)
                loaded_models[policy['name']] = model.GCNPolicy()
                del sys.path[0]
            policy['model'] = loaded_models[policy['name']]

    # load ml-competitor models
    for policy in branching_policies:
        if policy['type'] == 'ml-competitor':
            # Fall back to identity normalization (shift 0, scale 1) when no
            # normalization.pkl exists.
            # NOTE(review): bare except also hides I/O and unpickling errors.
            try:
                with open(f"{policy['model']}/normalization.pkl", 'rb') as f:
                    policy['feat_shift'], policy['feat_scale'] = pickle.load(f)
            except:
                policy['feat_shift'], policy['feat_scale'] = 0, 1

            # Feature specification is required (no fallback).
            with open(f"{policy['model']}/feat_specs.pkl", 'rb') as f:
                policy['feat_specs'] = pickle.load(f)
Example 5
0
        # Cache one student Policy instance per policy *type*; each model is
        # moved to `device` and put in eval (inference) mode once.
        if policy['type'] not in loaded_models:
            sys.path.insert(0, os.path.abspath(f"models/{policy['type']}"))
            import model
            importlib.reload(model)
            loaded_models[policy['type']] = model.Policy()
            del sys.path[0]
            loaded_models[policy['type']].to(device)
            loaded_models[policy['type']].eval()

        # Same caching for the optional GCNN teacher, keyed by teacher type.
        if (policy['teacher_type'] is not None
                and policy['teacher_type'] not in loaded_models):
            sys.path.insert(
                0, os.path.abspath(f"models/{policy['teacher_type']}"))
            import model
            importlib.reload(model)
            loaded_models[policy['teacher_type']] = model.GCNPolicy()
            del sys.path[0]
            loaded_models[policy['teacher_type']].to(device)
            loaded_models[policy['teacher_type']].eval()

        # NOTE(review): students and teachers share one `loaded_models` dict;
        # a teacher_type equal to some policy type would collide — confirm
        # the two key sets are disjoint.
        policy['model'] = loaded_models[policy['type']]
        policy['teacher_model'] = loaded_models[policy[
            'teacher_type']] if policy['teacher_type'] is not None else None

    print("running SCIP...")

    # Column names for the results CSV written below.
    fieldnames = [
        'problem',
        'device',
        'policy',
        'seed',