def create_noise(batched_data, model):
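    """Decode every batch with the LSTM-CRF `model` and return its predicted tag
    sequences, SOS removed and re-padded to sentence length, for use as noise."""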
    noise_batch = []
    for x, f, y in batched_data:
        noise = model.decode(x, f)
        noise = [ns[1:] for ns in noise]  # drop the leading SOS to match the sentence length
        noise = Process.process_noise(noise)  # re-pad to sentence length and convert to a tensor
        noise_batch.append(noise)
    return noise_batch
Example #2
def train_generator(discriminator, generator, batched_data, model,
                    G_optimizer):
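    """WGAN-style generator update: the frozen discriminator scores the generated
    tag sequences and `backward(mone)` flips the gradient so the generator ascends
    that score. Returns the generator, the elapsed time and the last epoch's loss."""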
    freeze_net(discriminator)
    unfreeze_net(generator)

    timer = time.time()

    for ei in range(cfg.GEN_ITR):
        count = 0
        loss = 0
        for x, f, y in batched_data:
            G_optimizer.zero_grad()
            noise = model.decode(x, f)
            noise = [ns[1:] for ns in noise]  # drop the leading SOS to match the sentence length
            noise = Process.process_noise(noise)  # re-pad to sentence length and convert to a tensor
            noise.requires_grad_(True)
            #print('shape of noise {} {}'.format(noise.shape, x.shape))
            #print(f)
            fake_data = generator(x, f, noise, y)
            fake_data = [rs[1:] for rs in fake_data]  # drop SOS to match the sentence length
            fake_data = Process.process_noise(fake_data)
            validity_gen = discriminator(x, f, fake_data)
            gen_cost = validity_gen.mean()
            gen_cost.backward(mone)  # mone flips the gradient sign so the critic score is ascended
            gen_cost = -gen_cost  # track the generator cost (negative critic score)
            G_optimizer.step()
            count = count + 1
            loss = loss + gen_cost
        loss = loss / count

    timer = time.time() - timer
    unfreeze_net(discriminator)
    return generator, timer, loss
def train_disc(discriminator,
               generator,
               batched_data,
               model,
               D_optimizer,
               from_gen=True):
    """ train the disc """
    freeze_net(generator)
    freeze_net(model)
    unfreeze_net(discriminator)
    timer = time.time()
    for ei in range(cfg.CRITIC_ITR):
        d_loss_avg = 0
        count = 0
        print('@disc epoch {}'.format(ei))
        for x, f, y in batched_data:
            prob = np.random.uniform(0, 1)
            if prob > 0:  # threshold 0 means every batch is used
                count = count + 1
                noise = model.decode(x, f)
                noise = [ns[1:] for ns in noise]  # drop the leading SOS to match the sentence length
                noise = Process.process_noise(noise)  # re-pad to sentence length and convert to a tensor

                #print('shape of noise {} {}'.format(noise.shape, x.shape))
                #print(f)
                if from_gen:
                    fake_data = generator.decode(x, f, noise)
                    fake_data = [rs[1:] for rs in fake_data]  # drop SOS to match the sentence length
                    fake_data = Process.process_noise(fake_data)  # re-pad and convert to a tensor
                    validity_gen = discriminator(x, f, fake_data.detach())
                else:
                    validity_gen = discriminator(x, f, noise.detach())
                y = y[:, 1:]  # drop SOS to match the sentence length
                #validity_gen = discriminator(x,f,fake_data.detach())

                validity_act = discriminator(x, f, y)  # act sen, act tag

                ## -- find the hardest in-batch negative -- ##
                # pair each sentence with the gold tags of every other example in the batch
                validity_wrong = discriminator(x, f, y[0].expand_as(y)).squeeze(-1)
                #print('wrong score:{}'.format(validity_wrong))
                for i in range(1, y.shape[0]):
                    vw = discriminator(x, f, y[i].expand_as(y)).squeeze(-1)
                    #print(vw)
                    validity_wrong = torch.cat((validity_wrong, vw), 0)  # actual tags, wrong sentence

                validity_wrong = validity_wrong.view(y.shape[0], -1)
                validity_wrong = validity_wrong.t()
                mask = eye(validity_wrong.size(0)) > .5
                validity_wrong = validity_wrong.masked_fill_(mask, 0)  # zero out the correct pairings
                validity_wrong = validity_wrong.max(1)[0]  # highest-scoring wrong pairing per sentence
                validity_wrong = validity_wrong.unsqueeze(-1)
                y_real = ones(validity_act.shape)
                y_fake = zeros(validity_gen.shape)

                D_fake_loss = BCE_loss(validity_gen, y_fake)
                D_hn_loss = BCE_loss(validity_wrong, y_fake)
                D_real_loss = BCE_loss(validity_act, y_real)

                #D_train_loss = D_real_loss + D_fake_loss
                D_train_loss = D_real_loss + 0.5 * (D_fake_loss + D_hn_loss)
                D_optimizer.zero_grad()
                D_train_loss.backward()
                D_optimizer.step()

                d_loss_avg += scalar(D_train_loss.detach())

        d_loss_avg /= count
        print('discriminator loss {}'.format(d_loss_avg))
    timer = time.time() - timer  # total time spent on the critic iterations
    unfreeze_net(generator)
    return discriminator, timer, d_loss_avg, D_optimizer
def train_generator(discriminator,
                    generator,
                    batched_data,
                    model,
                    G_optimizer,
                    epochs=cfg.GEN_ITR,
                    isolation=False):
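    """Generator update that weights each sentence's generator loss by how strongly
    the discriminator rejects its output (1 - validity); with isolation=True the
    discriminator is ignored and the plain mean loss is used."""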
    #freeze_net(discriminator)
    unfreeze_net(generator)

    timer = time.time()
    for ei in range(epochs):
        G_loss_avg = 0
        count = 0
        print('@generator epoch {}'.format(ei))
        for x, f, y in batched_data:
            prob = np.random.uniform(0, 1)
            if prob > 0:  # threshold 0 means every batch is used
                count = count + 1
                noise = model.decode(x, f)
                noise = [ns[1:] for ns in noise]  # drop the leading SOS to match the sentence length
                noise = Process.process_noise(noise)  # re-pad to sentence length and convert to a tensor
                #print('shape of noise {} {}'.format(noise.shape, x.shape))
                #print(f)
                loss = generator(x, f, noise, y)
                fake_data = generator.decode(x, f, noise)
                fake_data = [rs[1:] for rs in fake_data]  # drop SOS to match the sentence length
                fake_data = Process.process_noise(fake_data)
                validity_gen = discriminator(x, f, fake_data)
                #y_real = ones(validity_gen.shape)
                #bc_loss = BCE_loss(validity_gen, y_real)
                # how strongly the discriminator rejects each generated sequence
                validity_gen = 1 - validity_gen.squeeze(-1)
                #print('shape of loss {}'.format(loss))
                #print('shape of score {}'.format(validity_gen))
                #print('shape of bc_loss {} '.format(bc_loss .shape))
                #G_train_loss = BCE_loss(validity_gen, y_real)+loss
                if isolation:
                    G_train_loss = torch.mean(loss)
                else:
                    # weight each sentence's loss by its rejection score, averaged over the batch
                    G_train_loss = torch.dot(validity_gen, loss) / x.shape[0]

                #print('shape of mul score {}'.format(G_train_loss))

                #G_train_loss = Variable(G_train_loss, requires_grad = True)

                G_optimizer.zero_grad()
                G_train_loss.backward()
                G_optimizer.step()

                G_loss_avg += scalar(G_train_loss.detach())

        G_loss_avg /= count
        print('generator loss {}'.format(G_loss_avg))
    timer = time.time() - timer  # total time spent on the generator iterations
    unfreeze_net(discriminator)
    return generator, timer, G_loss_avg, G_optimizer

if __name__ == '__main__':
    # Binary Cross Entropy loss
    # argv[1] == training filename
    # argv[2] == model name to be saved
    # argv[3] == GPU id
    if len(sys.argv) > 3:
        torch.cuda.set_device(int(sys.argv[3]))
    fname = EMBEDDING
    filename = os.path.join(cfg.TRAINDED_MODEL_PATH, sys.argv[2])
    load_kwargs = {"vocab_size": 400000, "dim": 300}
    w = Embedding.from_glove(fname, **load_kwargs)
    dL = DataLoader()
    wdic = WordDictionary(w)
    tdic = TagDictionary()
    dL.readSRLData(sys.argv[1], wdic, tdic, False)
    batched_data = Process.create_batch_data(dL.sentences, cfg.BATCH_SIZE,
                                             wdic, tdic)
    print('saving dictionaries')

    save_word_to_idx(cfg.WORD_2_IDX_PATH, wdic)
    save_tag_to_idx(cfg.TAG_2_IDX_PATH, tdic)

    model = lstm_crf(len(wdic.word2idx), len(tdic.tag2idx), False)
    epoch = load_checkpoint(CRF_READ_GAN, model)
    model.eval()

    generator = Generator(len(wdic.word2idx), len(tdic.tag2idx), True,
                          wdic.getWeight(), tdic)
    #generator.set_crf(model.crf)
    discriminator = Discriminator(len(wdic.word2idx), len(tdic.tag2idx), True,
                                  wdic.getWeight())
Example #6
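    # step decay: divide the learning rate by 10 every 10 epochs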
    lr = learning_rate * (0.1**(epoch // 10))
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr


if __name__ == '__main__':

    fname = EMBEDDING
    filename = os.path.join(cfg.TRAINDED_MODEL_PATH, sys.argv[2])
    load_kwargs = {"vocab_size": 400000, "dim": 300}
    w = Embedding.from_glove(fname, **load_kwargs)
    dL = DataLoader()
    wdic = WordDictionary(w)
    tdic = TagDictionary()
    dL.readSRLData(sys.argv[1], wdic, tdic, False)
    batched_data = Process.create_batch_data(dL.sentences, cfg.BATCH_SIZE,
                                             wdic, tdic)
    print('saving dictionaries')

    save_word_to_idx(cfg.WORD_2_IDX_PATH, wdic)
    save_tag_to_idx(cfg.TAG_2_IDX_PATH, tdic)

    model = lstm_crf(len(wdic.word2idx), len(tdic.tag2idx), False)
    epoch = load_checkpoint(CRF_READ_GAN, model)
    model.eval()

    generator = Generator(len(wdic.word2idx), len(tdic.tag2idx), True,
                          wdic.getWeight(), tdic)
    generator.set_crf(model.crf)
    discriminator = Discriminator(len(wdic.word2idx), len(tdic.tag2idx), True,
                                  wdic.getWeight())
Example #7
    wdic = load_word_to_idx(cfg.WORD_2_IDX_PATH)
    tdic = load_tag_to_idx(cfg.TAG_2_IDX_PATH)

    model = lstm_crf(len(wdic.word2idx), len(tdic.tag2idx), False)
    epoch = load_checkpoint(CRF_READ_GAN, model)

    generator = Generator(len(wdic.word2idx), len(tdic.tag2idx), False)
    model.eval()

    epoch = load_checkpoint(sys.argv[3], generator)
    print(generator.crf.parameters())
    #print(tdic.idx2tag)
    dL = DataLoader()
    dL.readSRLTestData(sys.argv[1], wdic, tdic, True)
    #print(dL.sentences[0])
    batched_data = Process.create_batch_data_dev(dL.sentences, cfg.BATCH_SIZE,
                                                 wdic, tdic)
    # alternative: create_batch_data_testing(dL.sentences, cfg.BATCH_SIZE, wdic, tdic)
    #print(batch_data)
    all_result = [[]] * len(dL.sentences)
    for x, f, y, indices in batched_data:
        noise = model.decode(x, f)
        noise = [ns[1:] for ns in noise]  # drop the leading SOS
        noise = Process.process_noise(noise)  # re-pad to sentence length and convert to a tensor

        #print('shape of noise {} {}'.format(noise.shape, x.shape))
        #print(f)
        result = generator(x, f, noise, y)
        result = [[tdic.getTag(j) for j in result[i]]
                  for i in range(len(result))]
        for i, idx in enumerate(indices):
Example #8
def train_disc(discriminator, generator, batched_data, model, D_optimizer):
    """ train the disc """
    freeze_net(generator)
    unfreeze_net(discriminator)
    timer = time.time()

    for ei in range(cfg.CRITIC_ITR):
        count = 0
        loss = 0
        for x, f, y in batched_data:
            count = count + 1
            noise = model.decode(x, f)
            noise = [ns[1:] for ns in noise]  # drop the leading SOS to match the sentence length
            noise = Process.process_noise(noise)  # re-pad to sentence length and convert to a tensor

            #print('shape of noise {} {}'.format(noise.shape, x.shape))
            #print(f)
            fake_data = generator(x, f, noise, y)
            fake_data = [rs[1:] for rs in fake_data]  # drop SOS to match the sentence length
            fake_data = Process.process_noise(fake_data)  # re-pad and convert to a tensor
            y = y[:, 1:]  # drop SOS to match the sentence length

            validity_gen = discriminator(x, f, fake_data.detach())
            disc_fake = validity_gen.mean()
            validity_act = discriminator(x, f, y)  # act sen, act tag
            disc_real = validity_act.mean()
            '''
            ##-- find the hardest negative----##
            validity_wrong = discriminator(x,f,y[0].expand_as(y)).squeeze(-1)
            #print('wrong score:{}'.format(validity_wrong))
            for i in range(1,y.shape[0]):
                vw = discriminator(x,f,y[i].expand_as(y)).squeeze(-1)
                #print(vw)
                validity_wrong = torch.cat((validity_wrong, vw),0) #act tag, WRONG sen

            validity_wrong = validity_wrong.view(y.shape[0],-1)
            validity_wrong = validity_wrong.t()
            mask = eye(validity_wrong .size(0)) > .5
            sc_wrong = validity_wrong .masked_fill_(mask, 0)
            #print(sc_gen);
            sc_wrong = sc_wrong.max(1)[0]
            '''
            gradient_penalty = calc_gradient_penalty(discriminator, y,
                                                     fake_data, x, f)
            D_optimizer.zero_grad()
            disc_cost = disc_fake - disc_real + gradient_penalty  # WGAN-GP critic objective
            disc_cost.backward()
            w_dist = disc_fake - disc_real  # estimate of the Wasserstein distance
            D_optimizer.step()
            loss = loss + disc_cost

        loss = loss / count

    timer = time.time() - timer

    unfreeze_net(generator)
    return discriminator, timer, loss
Example #9
if __name__ == '__main__':
    #-- load things --#
    print('usage: python predict.py devfile goldfile model')
    wdic = load_word_to_idx(cfg.WORD_2_IDX_PATH)
    tdic = load_tag_to_idx(cfg.TAG_2_IDX_PATH)
    model = lstm_crf(len(wdic.word2idx), len(tdic.tag2idx), False)
    model.eval()

    epoch = load_checkpoint(sys.argv[3], model)
    print(model.crf.parameters())
    #print(tdic.idx2tag)
    dL = DataLoader()
    dL.readSRLTestData(sys.argv[1], wdic, tdic, True)
    #print(dL.sentences[0])
    batched_data = Process.create_batch_data_dev(dL.sentences, cfg.BATCH_SIZE,
                                                 wdic, tdic)
    # alternative: create_batch_data_testing(dL.sentences, cfg.BATCH_SIZE, wdic, tdic)
    #print(batch_data)
    all_result = [[]] * len(dL.sentences)
    for x, f, _, indices in batched_data:
        result = model.decode(x, f)
        result = [[tdic.getTag(j) for j in result[i]]
                  for i in range(len(result))]
        for i, idx in enumerate(indices):
            slen = len(result[i])
            all_result[idx] = (result[i][1:slen - 1]
                               if USE_SE else result[i][0:slen - 1])
            #all_result[idx] = bio_to_se( all_result[idx])

    #print(all_result)
    evaluate(all_result, sys.argv[2])