Example #1
    def _setup(self, config):
        print('NaruTrainer config:', config)
        os.chdir(config["cwd"])
        for k, v in config.items():
            setattr(self, k, v)
        self.epoch = 0

        if callable(self.text_eval_corpus):
            self.text_eval_corpus = self.text_eval_corpus()

        # Try to make all the runs the same, except for input orderings.
        torch.manual_seed(0)
        np.random.seed(0)

        assert self.dataset in [
            'dmv', 'dmv-full', 'census',
            'synthetic', 'kdd', 'kdd-full', 'url', 'url-tiny', 'dryad-urls',
            'dryad-urls-small'
        ]
        if self.shuffle_at_data_level:
            data_order_seed = self.order_seed
        else:
            data_order_seed = None
        if self.dataset == 'dmv-full':
            table = datasets.LoadDmv(full=True, order_seed=data_order_seed)
        elif self.dataset == 'dmv':
            table = datasets.LoadDmv(order_seed=data_order_seed)
        elif self.dataset == 'synthetic':
            table = datasets.LoadSynthetic(order_seed=data_order_seed)
        elif self.dataset == 'census':
            table = datasets.LoadCensus(order_seed=data_order_seed)
        elif self.dataset == 'kdd':
            table = datasets.LoadKDD(order_seed=data_order_seed)
        elif self.dataset == 'kdd-full':
            table = datasets.LoadKDD(full=True, order_seed=data_order_seed)
        elif self.dataset == 'url-tiny':
            table = datasets.LoadURLTiny()
        elif self.dataset == 'dryad-urls':
            table = datasets.LoadDryadURLs()
        elif self.dataset == 'dryad-urls-small':
            table = datasets.LoadDryadURLs(small=True)
        self.table = table
        self.oracle = Oracle(
            table, cache_dir=os.path.expanduser("~/oracle_cache"))
        try:
            self.table_bits = Entropy(
                self.table,
                self.table.data.fillna(value=0).groupby(
                    [c.name for c in table.columns]).size(), [2])[0]
        except Exception as e:
            print("Error computing table bits", e)
            self.table_bits = 0  # TODO(ekl) why does dmv-full crash on ec2

        fixed_ordering = None
        if self.special_orders <= 1:
            fixed_ordering = list(range(len(table.columns)))

        if self.entropy_order:
            assert self.num_orderings == 1
            res = []
            for i, c in enumerate(table.columns):
                bits = Entropy(c.name, table.data.groupby(c.name).size(), [2])
                res.append((bits[0], i))
            s = sorted(res, key=lambda b: b[0], reverse=self.reverse_entropy)
            fixed_ordering = [t[1] for t in s]
            print('Using fixed ordering:', '_'.join(map(str, fixed_ordering)))
            print(s)

        if self.order is not None:
            print('Using passed-in order:', self.order)
            fixed_ordering = self.order

        if self.order_seed is not None and not self.shuffle_at_data_level:
            if self.order_seed == "reverse":
                fixed_ordering = fixed_ordering[::-1]
            else:
                rng = np.random.RandomState(self.order_seed)
                rng.shuffle(fixed_ordering)
            print('Using generated order:', fixed_ordering)

        table.data.info()  # .info() prints its own summary; wrapping it in print() would just add a stray "None"
        self.fixed_ordering = fixed_ordering

        table_train = table

        if self.special_orders > 0:
            special_orders = _SPECIAL_ORDERS[self.dataset][:self.special_orders]
            k = len(special_orders)
            seed = self.special_order_seed * 10000
            for i in range(k, self.special_orders):
                special_orders.append(
                    np.random.RandomState(seed + i - k + 1).permutation(
                        np.arange(len(table.columns))))
            print('Special orders', np.array(special_orders))
        else:
            special_orders = []

        if self.use_transformer:
            args = {
                "num_blocks": 4,
                "d_model": 64,
                "d_ff": 256,
                "num_heads": 4,
                "nin": len(table.columns),
                "input_bins": [c.DistributionSize() for c in table.columns],
                "use_positional_embs": True,
                "activation": "gelu",
                "fixed_ordering": fixed_ordering,
                "dropout": False,
                "seed": self.seed,
                "first_query_shared": False,
                "prefix_dropout": self.prefix_dropout,
                "mask_scheme": 0,  # XXX only works for default order?
            }
            args.update(self.transformer_args)
            model = Transformer(**args).to(get_device())
        else:
            model = MakeMade(
                scale=self.fc_hiddens,
                cols_to_train=table.columns,
                seed=self.seed,
                dataset=self.dataset,
                fixed_ordering=fixed_ordering,
                special_orders=special_orders,
                layers=self.layers,
                residual=self.residual,
                embed_size=self.embed_size,
                dropout=self.dropout,
                per_row_dropout=self.per_row_dropout,
                prefix_dropout=self.prefix_dropout,
                fixed_dropout_ratio=self.fixed_dropout_ratio,
                input_no_emb_if_leq=self.input_no_emb_if_leq,
                disable_learnable_unk=self.disable_learnable_unk,
                embs_tied=self.embs_tied)

        child = None

        print(model.nin, model.nout, model.input_bins)
        blacklist = None
        mb = ReportModel(model, blacklist=blacklist)
        self.mb = mb

        if not isinstance(model, Transformer):
            print('applying weight_init()')
            model.apply(weight_init)

        if isinstance(model, Transformer):
            opt = torch.optim.Adam(
                list(model.parameters()) + (list(child.parameters())
                                            if child else []),
                2e-4,
                betas=(0.9, 0.98),
                eps=1e-9,
            )
        else:
            opt = torch.optim.Adam(
                list(model.parameters()) + (list(child.parameters())
                                            if child else []), 2e-4)

        self.train_data = TableDataset(table_train)

        self.model = model
        self.opt = opt

        if self.checkpoint_to_load:
            self.model.load_state_dict(torch.load(self.checkpoint_to_load))
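
A minimal usage sketch for the hook above, assuming NaruTrainer subclasses ray.tune.Trainable (which calls _setup(config) once at construction time); that base class is an assumption, not shown in the snippet. Only the handful of config keys read near the top of _setup are listed, and their values are illustrative, the real project supplies many more hyperparameters.

# Hypothetical driver code -- not part of the original example.
config = {
    'cwd': os.getcwd(),              # _setup() chdirs here first
    'dataset': 'dmv',                # must be one of the names in the assert above
    'shuffle_at_data_level': False,
    'order': None,
    'order_seed': None,
    'entropy_order': False,
    'special_orders': 0,
    'use_transformer': False,
    'seed': 0,
    'checkpoint_to_load': None,
    # ... plus the remaining keys read further down (fc_hiddens, layers,
    # embed_size, dropout flags, embs_tied, text_eval_corpus, ...)
}
trainer = NaruTrainer(config=config)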
Example #2
def main():
    train_data = SentenceDataset(args.train_file,
                                 encoding_type=args.encoding_type,
                                 filter_threshold=args.filter_threshold)
    val_data = SentenceDataset(args.val_file,
                               encoding_type=args.encoding_type,
                               filter_threshold=args.filter_threshold)

    train_loader = torch.utils.data.DataLoader(train_data,
                                               args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_data, args.batch_size)

    print(len(train_loader))

    input_dim = len(train_data.vocab.source_vocab)
    output_dim = len(train_data.vocab.target_vocab)
    static = args.embedding_type == 'static'

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    enc_embedding = Embeddings(input_dim, args.hidden_dim, args.max_len,
                               device, static)
    encoder_layer = EncoderLayer(args.hidden_dim, args.num_enc_heads,
                                 args.inner_dim, args.dropout)
    encoder = Encoder(enc_embedding, encoder_layer, args.num_enc_layers,
                      args.dropout)

    dec_embedding = Embeddings(input_dim, args.hidden_dim, args.max_len,
                               device, static)
    decoder_layer = DecoderLayer(args.hidden_dim, args.num_dec_heads,
                                 args.inner_dim, args.dropout)
    decoder = Decoder(output_dim, args.hidden_dim, dec_embedding,
                      decoder_layer, args.num_dec_layers, args.dropout)

    pad_id = train_data.vocab.source_vocab['<pad>']

    model = Transformer(encoder, decoder, pad_id, device)

    print('Transformer has {:,} trainable parameters'.format(
        count_parames(model)))

    if args.load_model is not None:
        model.load(args.load_model)
    else:
        model.apply(init_weights)

    if args.mode == 'test':
        inferencer = Inferencer(model, train_data.vocab, device)
        greedy_out = inferencer.infer_greedy(
            'helo world, I m testin a typo corector')
        print(greedy_out)

    elif args.mode == 'train':
        optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

        loss_function = nn.NLLLoss(ignore_index=pad_id)

        print('Started training...')
        train(model, train_loader, val_loader, optimizer, loss_function,
              device)

    else:
        raise ValueError('Mode not recognized')
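
main() reads its settings from a module-level args namespace that the snippet does not define. Below is a minimal argparse sketch covering the attributes accessed above; the flag names mirror those attributes, but the types and defaults are assumptions, not the original project's values.

# Hypothetical argument parsing -- defaults are illustrative only.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--train_file', required=True)
parser.add_argument('--val_file', required=True)
parser.add_argument('--encoding_type', default='char')
parser.add_argument('--filter_threshold', type=int, default=100)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--embedding_type', default='static')
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--max_len', type=int, default=100)
parser.add_argument('--inner_dim', type=int, default=512)
parser.add_argument('--num_enc_heads', type=int, default=8)
parser.add_argument('--num_dec_heads', type=int, default=8)
parser.add_argument('--num_enc_layers', type=int, default=3)
parser.add_argument('--num_dec_layers', type=int, default=3)
parser.add_argument('--dropout', type=float, default=0.1)
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--mode', choices=['train', 'test'], default='train')
parser.add_argument('--load_model', default=None)
args = parser.parse_args()

if __name__ == '__main__':
    main()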