Example #1
import sys

import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

# dataset, args, model, optimizer, scheduler, vocab, JTNNCollator,
# MAX_EPOCH, PRINT_ITER, and save_path are assumed to be defined at
# module level elsewhere in the script.

def train():
    dataset.training = True
    print("Loading data...")
    dataloader = DataLoader(
            dataset,
            batch_size=args.batch_size,
            shuffle=True,
            pin_memory=True,
            num_workers=12,
            collate_fn=JTNNCollator(vocab, True),
            drop_last=True,
            worker_init_fn=None)
    # Workaround touching a private DataLoader attribute; it may be a no-op
    # (or absent) in newer PyTorch releases.
    dataloader._use_shared_memory = False
    last_loss = sys.maxsize  # sentinel so the first reported delta is large
    print("Beginning Training...")
    for epoch in range(MAX_EPOCH):
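        # Running sums of the four accuracy metrics (word, topology,
        # assembly, stereo), reset every PRINT_ITER iterations.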
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
        print("Epoch %d: " % epoch)

        for it, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, args.beta)
            except Exception:
                # Log the offending SMILES strings before re-raising.
                print([t.smiles for t in batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            cur_loss = loss.item()
            
            if (it + 1) % PRINT_ITER == 0:
                # Convert the running sums into percentage averages.
                word_acc = word_acc / PRINT_ITER * 100
                topo_acc = topo_acc / PRINT_ITER * 100
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f, Delta: %.6f" % (
                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, cur_loss, last_loss - cur_loss))
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_last_lr()[0])
                
            if (it + 1) % 100 == 0:
                torch.save(model.state_dict(),
                           save_path + "/model.iter-%d-%d" % (epoch, it + 1))
                      
            # Optional early stop once the loss plateaus (left disabled):
            #if last_loss - cur_loss < 1e-5:
            #    break
            last_loss = cur_loss

        # End of epoch: decay the learning rate and checkpoint the model.
        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_last_lr()[0])
        torch.save(model.state_dict(), save_path + "/model.iter-" + str(epoch))
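
Example #2 (setup sketch)

For context, train() depends on module-level objects created elsewhere. The sketch below shows one plausible way to wire them up. The Adam optimizer, the ExponentialLR schedule, and every hyperparameter value here are illustrative assumptions rather than part of the example above, and nn.Linear is only a stand-in for the actual junction-tree VAE model.

import torch.nn as nn
from torch import optim
from torch.optim import lr_scheduler

MAX_EPOCH = 100              # assumed value
PRINT_ITER = 20              # assumed value
save_path = "./checkpoints"  # assumed location

model = nn.Linear(8, 8)  # stand-in; replace with your JTNN VAE model
optimizer = optim.Adam(model.parameters(), lr=1e-3)           # assumed choice
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.9)  # assumed decay

# dataset, vocab, and args must also be constructed before calling:
# train()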