Example #1
0
def gen_proc(comm, iters=10000, i=0, batch_size=4096):
    """Sample batches from a saved VAE and push the decoded strings to a queue.

    Loads ``charset.pkl`` and ``vocab.pkl`` from the working directory, builds
    a ``mosesvae.VAE`` from the vocabulary, restores its weights from
    ``trained_save_small.pt`` and moves it to CUDA device ``i``. It then
    samples ``iters`` batches; every sample is rendered as a string of
    bracketed ``charset`` entries and each batch is handed to ``comm``.

    Args:
        comm: Queue-like object; only ``put()`` and ``qsize()`` are used.
        iters: Number of batches to generate.
        i: CUDA device index passed to ``model.cuda``.
        batch_size: Number of samples requested per batch.

    Raises:
        SystemExit: When the process receives a ``KeyboardInterrupt``.
    """
    print("Generator on", i)
    try:
        # NOTE: pickle.load executes arbitrary code from the file; these
        # files must come from a trusted source.
        with open("charset.pkl", 'rb') as f:
            charset = pickle.load(f)
        with open("vocab.pkl", 'rb') as f:
            vocab = pickle.load(f)
        model = mosesvae.VAE(vocab)
        model.load_state_dict(
            torch.load("trained_save_small.pt", map_location='cpu'))
        model = model.cuda(i)

        for _ in range(iters):
            # Number of samples attempted in this batch (failures included).
            count = 0

            res, _ = model.sample(batch_size)

            smis = []
            # Use a dedicated index name so the device argument `i` is not
            # overwritten by the loop.
            for idx in range(batch_size):
                count += 1
                try:
                    s = "".join('[' + charset[sym] + ']' for sym in res[idx])
                    smis.append(s)
                except Exception:
                    # Best effort: report the sample that failed to decode
                    # and carry on. `Exception` (rather than a bare except)
                    # lets KeyboardInterrupt/SystemExit reach the outer
                    # handler instead of being logged as a decode error.
                    print("ERROR!!!")
                    print('res', res[idx])
                    print("charset", charset)
            comm.put((smis, count))
            # Back off while the consumer is behind so the queue does not
            # grow without bound.
            if comm.qsize() > 100:
                time.sleep(20)
    except KeyboardInterrupt:
        print("exiting")
        # `exit()` is injected by the `site` module and is not always
        # available; raising SystemExit ends the process the same way.
        raise SystemExit
Example #2
0
# vocab = mosesvocab.OneHotVocab.from_data(bindings.iloc[:,1].astype(str).tolist())
# The vocabulary is read from a pickle instead of being rebuilt from the data
# (see the disabled line above).
# NOTE: pickle.load executes arbitrary code; vocab.pkl must be trusted.
with open("vocab.pkl", 'rb') as f:
    vocab = pickle.load(f)
bdata = BindingDataSet(bindings)
train_loader = torch.utils.data.DataLoader(
    bdata,
    batch_size=128,
    shuffle=True,
    num_workers=8,
    collate_fn=get_collate_fn_binding(),
    worker_init_fn=mosesvocab.set_torch_seed_to_all_gens,
    pin_memory=True)

n_epochs = 50

model = mosesvae.VAE(vocab)
pt = torch.load("trained_save.pt", map_location='cpu')
from collections import OrderedDict

# The checkpoint keys are expected to carry a leading `module.` (presumably
# because the model was saved from a DataParallel-style wrapper; confirm).
# Strip the prefix only where it is actually present, so a checkpoint saved
# without the wrapper does not get the first 7 characters of every key cut off.
module_prefix = "module."
new_state_dict = OrderedDict()
for k, v in pt.items():
    name = k[len(module_prefix):] if k.startswith(module_prefix) else k
    new_state_dict[name] = v

model.load_state_dict(new_state_dict)
model = model.cuda()
# Only the parameters of `model.binding_model` are handed to the optimizer.
optimizer = optim.Adam(model.binding_model.parameters(), lr=5e-4)
#model.eval()


def _train_epoch_binding(model, epoch, tqdm_data, optimizer=None):
    if optimizer is None:
Example #3
0
        bdata,
        batch_size=128,
        shuffle=False,
        sampler=torch.utils.data.RandomSampler(bdata,
                                               replacement=True,
                                               num_samples=128 * 100),
        num_workers=32,
        collate_fn=get_collate_fn_binding(),
        worker_init_fn=mosesvocab.set_torch_seed_to_all_gens,
        pin_memory=True,
    )


# Number of training epochs.
n_epochs = 100

# Build the VAE on the default CUDA device and run `init_weights` over it
# through `model.apply` (presumably nn.Module.apply, i.e. applied to every
# submodule; confirm against mosesvae.VAE).
model = mosesvae.VAE(vocab).cuda()
model.apply(init_weights)
# No optimizer is created for the binding part in this script.
binding_optimizer = None

# optimizer = optim.Adam(model.parameters() ,
#                                lr=3*1e-3 )
# Encoder and decoder get separate Adam optimizers with different
# learning rates.
encoder_optimizer = optim.Adam(model.encoder.parameters(), lr=8e-4)
decoder_optimizer = optim.Adam(model.decoder.parameters(), lr=5e-4)
# model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
# model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)

# The argument 100 matches n_epochs above; presumably the annealing length in
# epochs. TODO confirm against KLAnnealer.
kl_annealer = KLAnnealer(100)
# NOTE(review): the suffixes look swapped. `lr_annealer_d` wraps the
# *encoder* optimizer and `lr_annealer_e` wraps the *decoder* optimizer.
# Check how these names are used further down before renaming.
lr_annealer_d = CosineAnnealingLRWithRestart(encoder_optimizer)
lr_annealer_e = CosineAnnealingLRWithRestart(decoder_optimizer)

# Clear the model's gradients before training starts.
model.zero_grad()