def get_graph_model(codec, drop_rate, model_type, output_type='values', num_bins=51):
    if 'conditional' in model_type:
        if 'sparse' in model_type:
            model = CondtionalProbabilityModelSparse(codec.grammar)
        else:
            model = ConditionalModelBlended(codec.grammar)
        model.init_encoder_output = lambda x: None
        return model

    # for all other models, start with a GraphEncoder and attach the right head
    encoder = GraphEncoder(grammar=codec.grammar,
                           d_model=512,
                           drop_rate=drop_rate,
                           model_type=model_type)
    if output_type == 'values':
        model = MultipleOutputHead(model=encoder,
                                   output_spec={'node': 1,  # to be used to select next node to expand
                                                'action': codec.feature_len()},  # to select the action for chosen node
                                   drop_rate=drop_rate)
        model = OneValuePerNodeRuleTransform(model)
    elif 'distributions' in output_type:
        model = MultipleOutputHead(model=encoder,
                                   output_spec={'node': 1,  # to be used to select next node to expand
                                                'action': codec.feature_len()*num_bins},  # to select the action for chosen node
                                   drop_rate=drop_rate)
        if 'thompson' in output_type:
            model = DistributionPerNodeRuleTransformThompson(model, num_bins=num_bins)
        elif 'softmax' in output_type:
            model = DistributionPerNodeRuleTransformSoftmax(model, num_bins=num_bins, T=10)
    else:
        raise ValueError('unknown output_type: ' + str(output_type))
    model.init_encoder_output = lambda x: None
    return model
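A hedged usage sketch of this factory; `my_codec` is a hypothetical placeholder for any codec exposing the `.grammar` attribute and `.feature_len()` method used above, and the `model_type` value is likewise a placeholder:

# Hypothetical usage sketch; argument values are placeholders.
model = get_graph_model(my_codec,
                        drop_rate=0.1,
                        model_type='graph_attention',
                        output_type='distributions_thompson')
# 'node' head: one score per node, for choosing which node to expand;
# 'action' head: feature_len()*num_bins logits, interpreted as binned
# per-rule value distributions by the Thompson-sampling transform.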
Example 2
def test_batch_crosstalk(self):
    output_shape = [10, 15, 20]
    m = PassthroughModel(output_shape)
    h = MultipleOutputHead(m, [2], drop_rate=0).to(device)
    inp = torch.ones(*output_shape).to(device)
    out1 = h(inp)[0]
    out2 = h(inp[:1])[0]
    assert torch.max((out1[:1] - out2).abs()) < 1e-6
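`PassthroughModel` is a test helper defined elsewhere; a minimal stand-in consistent with its use here might look like the sketch below (hypothetical, assuming `MultipleOutputHead` only needs the wrapped model's `output_shape` attribute and its forward output):

from torch import nn

class PassthroughModel(nn.Module):
    # Hypothetical stand-in: advertises an output_shape and returns its
    # input unchanged, which is all the test above requires.
    def __init__(self, output_shape):
        super().__init__()
        self.output_shape = output_shape

    def forward(self, x):
        return x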
Example 3
def __init__(self, grammar, drop_rate=0.0, d_model=512):
    super().__init__()
    encoder = GraphEncoder(grammar=grammar,
                           d_model=d_model,
                           drop_rate=drop_rate)
    encoder_aggregated = FirstSequenceElementHead(encoder)
    self.discriminator = MultipleOutputHead(encoder_aggregated,
                                            {'p_zinc': 2},
                                            drop_rate=drop_rate).to(device)
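A hedged usage sketch for this wrapper, assuming (as in the tests below) that the discriminator accepts a list of HyperGraph objects and that a dict output_spec yields name-keyed outputs:

# Hypothetical usage; `wrapper` is an instance of the (unnamed) class
# above, `graphs` a list of HyperGraph objects as in the tests below.
out = wrapper.discriminator(graphs)
p_zinc_logits = out['p_zinc']  # presumed shape: (len(graphs), 2)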
Example 4
def __init__(self, grammar, output_spec, drop_rate=0.0, d_model=512):
    super().__init__()
    encoder = GraphEncoder(grammar=grammar,
                           d_model=d_model,
                           drop_rate=drop_rate)
    self.model = MultipleOutputHead(encoder,
                                    output_spec,
                                    drop_rate=drop_rate).to(device)

    # don't support using this model in VAE-style models yet
    self.init_encoder_output = lambda x: None
    self.output_shape = self.model.output_shape
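For reference, a hedged instantiation sketch; the enclosing class name is not shown in the snippet, so `GraphOutputModel` below is hypothetical, and the output_spec mirrors the 'node'/'action' layout used in get_graph_model above:

# Hypothetical instantiation; class name and values are placeholders.
wrapper = GraphOutputModel(grammar=grammar,
                           output_spec={'node': 1, 'action': 64},
                           drop_rate=0.1)
print(wrapper.output_shape)  # forwarded from the underlying MultipleOutputHead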
Example 5
def test_full_discriminator_parts_tuple_head(self):
    encoder = GraphEncoder(grammar=gi.grammar, d_model=512, drop_rate=0.0)

    encoder_aggregated = FirstSequenceElementHead(encoder)
    discriminator = MultipleOutputHead(encoder_aggregated, [2],
                                       drop_rate=0).to(device)
    mol_graphs = [
        HyperGraph.from_mol(mol) for mol in get_zinc_molecules(5)
    ]
    out = discriminator(mol_graphs)[0]
    out2 = discriminator(mol_graphs[:1])[0]
    assert out.size(0) == len(mol_graphs)
    assert out.size(1) == 2
    assert len(out.size()) == 2
    assert torch.max((out[0, :] - out2[0, :]).abs()) < 1e-5
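Note the two spec styles exercised across these examples: a list spec such as `[2]` yields positionally indexed outputs (`out[0]`), while a dict spec such as `{'p_zinc': 2}` presumably yields name-keyed outputs. A minimal side-by-side sketch, reusing the names from the test above:

# List spec: outputs indexed by position (as in the test above).
head_list = MultipleOutputHead(encoder_aggregated, [2], drop_rate=0)
logits = head_list(mol_graphs)[0]

# Dict spec: outputs presumably keyed by name instead.
head_dict = MultipleOutputHead(encoder_aggregated, {'p_zinc': 2}, drop_rate=0)
logits = head_dict(mol_graphs)['p_zinc']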
Example 6
def get_graph_model(codec, drop_rate, model_type, output_type='values', num_bins=51):
    encoder = GraphEncoder(grammar=codec.grammar,
                           d_model=512,
                           drop_rate=drop_rate,
                           model_type=model_type)
    if output_type == 'values':
        model = MultipleOutputHead(model=encoder,
                                   output_spec={'node': 1,  # to be used to select next node to expand
                                                'action': codec.feature_len()},  # to select the action for chosen node
                                   drop_rate=drop_rate)
        model = OneValuePerNodeRuleTransform(model)
    elif 'distributions' in output_type:
        model = MultipleOutputHead(model=encoder,
                                   output_spec={'node': 1,  # to be used to select next node to expand
                                                'action': codec.feature_len()*num_bins},  # to select the action for chosen node
                                   drop_rate=drop_rate)
        if 'thompson' in output_type:
            model = DistributionPerNodeRuleTransformThompson(model, num_bins=num_bins)
        elif 'softmax' in output_type:
            model = DistributionPerNodeRuleTransformSoftmax(model, num_bins=num_bins, T=10)
    else:
        raise ValueError('unknown output_type: ' + str(output_type))
    model.init_encoder_output = lambda x: None
    return model
def train_dependencies(EPOCHS=None,
                       BATCH_SIZE=None,
                       max_steps=None,
                       feature_len=None,
                       lr=2e-4,
                       drop_rate=0.0,
                       plot_ignore_initial=1000,
                       save_file=None,
                       preload_file=None,
                       meta=None,
                       languages=None,
                       decoder_type='action',
                       use_self_attention=True,
                       vae=True,
                       target_names=['head'],
                       include_predefined_embedding=True,
                       plot_prefix='',
                       dashboard='policy gradient',
                       ignore_padding=True):

    root_location = os.path.dirname(
        os.path.abspath(inspect.getfile(inspect.currentframe())))
    root_location = root_location + '/../'
    if save_file is not None:
        save_path = root_location + 'pretrained/' + save_file
    else:
        save_path = None

    settings = {}  # previously: get_settings(molecules=molecules, grammar=grammar)

    if EPOCHS is not None:
        settings['EPOCHS'] = EPOCHS
    if BATCH_SIZE is not None:
        settings['BATCH_SIZE'] = BATCH_SIZE

    # the same for all languages by construction
    n_src_vocab = meta['num_tokens'] + 1  # TODO: remove the +1 after next ingest
    d_model = 512
    if languages is not None:
        multi_embedder = MultiEmbedder(languages, meta['predefined'],
                                       n_src_vocab, d_model)
    else:
        multi_embedder = None

    embedder1 = Embedder(
        max_steps,
        n_src_vocab,  # feature_len
        encode_position=True,
        include_learned=True,
        include_predefined=include_predefined_embedding,
        float_input=False,
        custom_embedder=multi_embedder)
    encoder = TransformerEncoder(n_src_vocab,
                                 max_steps,
                                 dropout=drop_rate,
                                 padding_idx=0,
                                 embedder=embedder1,
                                 use_self_attention=use_self_attention,
                                 d_model=d_model)

    z_size = encoder.output_shape[2]

    embedder2 = Embedder(
        max_steps,
        z_size,  # feature_len
        encode_position=True,
        include_learned=True,
        include_predefined=False,
        float_input=True,
    )

    encoder_2 = TransformerEncoder(z_size,
                                   max_steps,
                                   dropout=drop_rate,
                                   padding_idx=0,
                                   embedder=embedder2,
                                   use_self_attention=use_self_attention,
                                   d_model=d_model)

    decoder = EncoderAsDecoder(encoder_2)

    pre_model_2 = VariationalAutoEncoderHead(encoder=encoder,
                                             decoder=decoder,
                                             z_size=z_size,
                                             return_mu_log_var=False)

    if vae:
        pre_model = pre_model_2
    else:
        pre_model = encoder

    model_outputs = {
        'head': meta['maxlen'],  # dependency head position
        'upos': len(meta['upos']),  # universal part-of-speech tag
        'deprel': len(meta['deprel'])  # dependency relation
    }
    if languages is not None:
        for i in range(len(languages)):
            model_outputs[str(i + 1)] = n_src_vocab  # word
        loss = MultipleCrossEntropyLoss(multi_language='token',
                                        ignore_padding=ignore_padding)
    else:
        model_outputs['token'] = n_src_vocab
        loss = MultipleCrossEntropyLoss(ignore_padding=ignore_padding)

    model = MultipleOutputHead(
        pre_model,
        model_outputs,
        drop_rate=drop_rate,
    )

    model = to_gpu(model)

    if preload_file is not None:
        try:
            preload_path = root_location + 'pretrained/' + preload_file
            model.load_state_dict(torch.load(preload_path))
        except (FileNotFoundError, RuntimeError):
            pass  # no (or incompatible) pretrained weights; train from scratch

    def model_process_fun(model_out, visdom, n):
        # hook for plotting intermediate model outputs; a no-op here
        pass

    def get_fitter(model,
                   train_gen,
                   valid_gen,
                   loss_obj,
                   fit_plot_prefix='',
                   model_process_fun=None,
                   lr=None,
                   loss_display_cap=float('inf')):
        nice_params = filter(lambda p: p.requires_grad, model.parameters())
        optimizer = optim.Adam(nice_params, lr=lr)
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=100)
        # alternative: lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.99)

        if dashboard is not None:
            index_to_lang_ordered = OrderedDict()
            for lang in (languages or []):
                index_to_lang_ordered[meta['predefined'][lang]] = lang
            metric_monitor = MetricPlotter(
                plot_prefix=fit_plot_prefix,
                loss_display_cap=loss_display_cap,
                dashboard_name=dashboard,
                plot_ignore_initial=plot_ignore_initial,
                process_model_fun=model_process_fun,
                extra_metric_fun=partial(language_metrics_for_monitor,
                                         index_to_lang=index_to_lang_ordered),
                smooth_weight=0.9)
        else:
            metric_monitor = None

        checkpointer = Checkpointer(valid_batches_to_checkpoint=1,
                                    save_path=save_path,
                                    save_always=False)

        fitter = fit(train_gen=train_gen,
                     valid_gen=valid_gen,
                     model=model,
                     optimizer=optimizer,
                     scheduler=scheduler,
                     epochs=EPOCHS,
                     loss_fn=loss_obj,
                     batches_to_valid=4,
                     metric_monitor=metric_monitor,
                     checkpointer=checkpointer)

        return fitter

    # TODO: need to be cleaner about dataset creation
    def get_data_loader(dtype, languages):
        if languages is None:
            languages = ['en']
        all_train_data = []
        for lang in languages:
            print('loading', dtype, lang)
            with gzip.open(meta['files'][lang][dtype], 'rb') as f:
                data = pickle.load(f)
                print('loaded', len(data), 'records')
                all_train_data.append(data)
        dataset = ConcatDataset(all_train_data)
        loader = torch.utils.data.DataLoader(dataset,
                                             batch_size=BATCH_SIZE,
                                             shuffle=True,
                                             pin_memory=use_gpu)
        return loader

    train_loader = get_data_loader('train', languages)
    valid_loader = get_data_loader('valid', languages)

    def extract_input(x):
        if include_predefined_embedding:
            return (to_gpu(x['token']), to_gpu(x['embed']))
        else:
            return to_gpu(x['token'])

    def nice_loader(loader):
        return IterableTransform(
            loader, lambda x: (extract_input(x), {
                key: to_gpu(val)
                for key, val in x.items() if key in target_names
            }))

    # build the fitter for supervised training
    fitter1 = get_fitter(model,
                         nice_loader(train_loader),
                         nice_loader(valid_loader),
                         loss,
                         plot_prefix,
                         model_process_fun=model_process_fun,
                         lr=lr)

    return model, fitter1
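A hedged driver sketch for this trainer; `meta` must provide the keys referenced above ('num_tokens', 'predefined', 'maxlen', 'upos', 'deprel', 'files'), and the returned fitter is assumed to be an iterator that advances training one batch per next() call, as the fit(...) factory and naming suggest:

# Hypothetical driver; all argument values are placeholders.
model, fitter = train_dependencies(EPOCHS=5,
                                   BATCH_SIZE=32,
                                   max_steps=128,
                                   meta=meta,
                                   languages=['en', 'de'],
                                   save_file='deps_model.h5')
for _ in range(10000):  # assuming the fitter yields once per batch
    next(fitter)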