def f_test_example(model, tokenizer, w2i, i2w):
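    # Encode a fixed example sentence, run it through the VAE, and decode
    # from the posterior mean to inspect the round-trip reconstruction.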
    raw_text = "since the wii wasn't being used much anymore in the living room , i thought i'd try to take it to the bedroom so i can watch hulu and stuff like that . this cable made it possible."
    input_text = f_raw2vec(tokenizer, raw_text, w2i, i2w)
    length_text = len(input_text)
    length_text = [length_text]
    print("length_text", length_text)

    input_tensor = torch.LongTensor(input_text)
    print('input_tensor', input_tensor)
    input_tensor = input_tensor.unsqueeze(0)
    if torch.is_tensor(input_tensor):
        input_tensor = to_var(input_tensor)

    length_tensor = torch.LongTensor(length_text)
    print("length_tensor", length_tensor)
    # length_tensor = length_tensor.unsqueeze(0)
    if torch.is_tensor(length_tensor):
        length_tensor = to_var(length_tensor)

    print("*" * 10)
    print("->" * 10,
          *idx2word(input_tensor, i2w=i2w, pad_idx=w2i['<pad>']),
          sep='\n')
    logp, mean, logv, z = model(input_tensor, length_tensor)

    mean = mean.unsqueeze(0)
    # print("mean", mean)
    # print("z", z)

    samples, z = model.inference(z=mean)
    print("<-" * 10,
          *idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']),
          sep='\n')
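
The helper f_raw2vec used above is not shown in this listing; a minimal sketch of what it presumably does (tokenize the raw string and map tokens to vocabulary indices, wrapped in <sos>/<eos>), assuming an NLTK-style tokenizer with a .tokenize() method:

def f_raw2vec(tokenizer, raw_text, w2i, i2w):
    # Hypothetical sketch: map each token to its vocabulary index,
    # falling back to <unk> for out-of-vocabulary words.
    tokens = tokenizer.tokenize(raw_text)
    indices = [w2i.get(t, w2i['<unk>']) for t in tokens]
    return [w2i['<sos>']] + indices + [w2i['<eos>']]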
Example #2
def main(args):

    with open(args.data_dir + '/poems.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(vocab_size=len(w2i),
                        sos_idx=w2i['<sos>'],
                        eos_idx=w2i['<eos>'],
                        pad_idx=w2i['<pad>'],
                        unk_idx=w2i['<unk>'],
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional,
                        condition_size=0)

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(
        torch.load(args.load_checkpoint, map_location=torch.device('cpu')))
    print("Model loaded from %s" % (args.load_checkpoint))

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()
    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
    # while True:
    #     samples, z = model.inference(n=1, condition=torch.Tensor([[1, 0, 0, 0, 0, 0, 0]]).cuda())
    #     poem = idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])[0]
    #     if 'love' in poem:
    #         breakpoint()

    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(
        torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    # samples, _ = model.inference(z=z, condition=torch.Tensor([[1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0], [1, 0, 0, 0, 0, 0, 0]]).cuda())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
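
The interpolate helper used here (and in the later examples) is not part of this listing; a minimal sketch under the assumption that it performs a per-dimension linear interpolation between two latent vectors and returns one row per interpolation step:

import numpy as np

def interpolate(start, end, steps):
    # Linearly interpolate between two 1-D latent vectors; the endpoint
    # handling (steps + 2 points including both ends) is an assumption.
    interpolation = np.zeros((start.shape[0], steps + 2))
    for dim, (s, e) in enumerate(zip(start, end)):
        interpolation[dim] = np.linspace(s, e, steps + 2)
    return interpolation.T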
Example #3
def main(args):

    with open(args.data_dir+'/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s"%(args.load_checkpoint))

    if torch.cuda.is_available():
        model = model.cuda()
    
    model.eval()

#     samples, z = model.inference(n=args.num_samples)
#     print('----------SAMPLES----------')
#     print(idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']))

    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']))
    
    model.load_state_dict(torch.load('bin/2019-May-16-04:24:16/E10.pytorch'))
    z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']))
Example #4
def main(args):

    # load checkpoint
    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    saved_dir_name = args.load_checkpoint.split('/')[2]
    params_path = './saved_vae_models/' + saved_dir_name + '/model_params.json'

    if not os.path.exists(params_path):
        raise FileNotFoundError(params_path)

    # load params
    params = load_model_params_from_checkpoint(params_path)

    # create and load model
    model = SentenceVae(**params)
    print(model)
    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s" % args.load_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    # load the vocab of the chosen dataset

    if model.dataset == 'yelp':
        print("Yelp dataset used!")

        with open(args.data_dir + '/yelp/yelp.vocab.json', 'r') as file:
            vocab = json.load(file)
    else:
        raise ValueError("Unsupported dataset: %s" % model.dataset)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    z1 = torch.randn([params['latent_size']]).numpy()
    z2 = torch.randn([params['latent_size']]).numpy()
    z = to_var(
        torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)

    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
Example #5
File: train.py Project: timbmg/DIAL-LV
    def inference(model, train_dataset, split, n=10, m=3):
        """ Executes the model in inference mode and returns string of inputs and corresponding
        generations.

        Parameters
        ----------
        model : DIAL-LV
            The DIAL-LV model.
        train_dataset : Dataset
            Training dataset to draw random input samples from.
        split : str
            'train', 'valid' or 'test', to enable/disable word_dropout.
        n : int
            Number of samples to draw.
        m : int
            Number of response generations.

        Returns
        -------
        string, string
            Two strings, each consisting of n utterances. `prompts` contains the input sequences
            and `replies` the generated response sequences.

        """

        random_input_idx = np.random.choice(np.arange(0, len(train_dataset)), n, replace=False).astype('int64')
        random_inputs = np.zeros((n, args.max_input_length)).astype('int64')
        random_inputs_length = np.zeros(n)
        for i, rqi in enumerate(random_input_idx):
            random_inputs[i] = train_dataset[rqi]['input_sequence']
            random_inputs_length[i] = train_dataset[rqi]['input_length']

        input_sequence = to_var(torch.from_numpy(random_inputs).long())
        input_length = to_var(torch.from_numpy(random_inputs_length).long())
        prompts = idx2word(input_sequence.data, train_dataset.i2w, train_dataset.pad_idx)

        replies = list()
        if split == 'train':
            model.eval()
        for i in range(m):
            replies_ = model.inference(input_sequence, input_length)
            replies.append(idx2word(replies_, train_dataset.i2w, train_dataset.pad_idx))

        if split == 'train':
            model.train()

        return prompts, replies
Example #6
def main(args):
    with open(args.data_dir + '/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(vocab_size=len(w2i),
                        sos_idx=w2i['<sos>'],
                        eos_idx=w2i['<eos>'],
                        pad_idx=w2i['<pad>'],
                        unk_idx=w2i['<unk>'],
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s" % args.load_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    # samples, z = model.inference(n=args.num_samples)
    # print('----------SAMPLES----------')
    # print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # z_ = torch.randn([args.latent_size]).numpy()
    # input_sent = "the n stock specialist firms on the big board floor the buyers and sellers of last resort who were criticized after the n crash once again could n't handle the selling pressure"
    input_sent = "looking for a job was one of the most anxious periods of my life and is for most people"
    batch_input = torch.LongTensor([[w2i[i]
                                     for i in input_sent.split()]]).cuda()
    batch_len = torch.LongTensor([len(input_sent.split())]).cuda()
    input_mean = model(batch_input, batch_len, output_mean=True)
    z_ = input_mean.cpu().detach().numpy()
    print(z_.shape)
    # z2 = torch.randn([args.latent_size]).numpy()
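    # Sweep each latent dimension: shift it by -0.5 / +0.5 around the encoded
    # mean and decode a short interpolation along that single axis.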
    for i in range(args.latent_size):
        print(f"-------Dimension {i}------")
        z1, z2 = z_.copy(), z_.copy()
        z1[i] -= 0.5
        z2[i] += 0.5
        z = to_var(
            torch.from_numpy(interpolate(start=z1, end=z2, steps=5)).float())
        samples, _ = model.inference(z=z)
        print('-------INTERPOLATION-------')
        print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
Example #7
def get_sents_and_tags(samples, i2w, w2i):
    """
    Preprocesses sentences, gets parses, and then returns phrase-tag occurrences.
    """
    samples = idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])
    samples = remove_bad_samples(samples)
    parses = get_parses(samples)
    tags = find_tags_in_parse(PHRASE_TAGS, parses)

    return(samples, tags)
Example #8
def main(args):

    with open(args.data_dir + '/snli_yelp/snli_yelp.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    if not os.path.exists(args.load_params):
        raise FileNotFoundError(args.load_params)

    # load params
    params = load_model_params_from_checkpoint(args.load_params)

    # create model
    model = SentenceVaeMultiTask(**params)

    print(model)
    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s" % args.load_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    # print(params['latent_size'])
    # exit()

    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    z1 = torch.randn([params['latent_size']]).numpy()
    z2 = torch.randn([params['latent_size']]).numpy()
    z = to_var(
        torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
Example #9
def get_interpolations(vae, sample_start, sample_end, args):
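    # Encode the two seed sentences, linearly interpolate between their latent
    # codes (and LSTM cell states, if applicable), and decode each step back
    # into text.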
    model = vae['model']
    tokenizer = vae['tokenizer']
    w2i = vae['w2i']
    i2w = vae['i2w']
    # Initialize semantic loss
    # sl = Semantic_Loss()

    start_encode = tokenizer.encode(sample_start)
    end_encode = tokenizer.encode(sample_end)
    with torch.no_grad():
        z1 = model._encode(**start_encode)
        z1_hidden = z1['z'].cpu()[0]

        z2 = model._encode(**end_encode)
        z2_hidden = z2['z'].cpu()[0]

    z_hidden = to_var(torch.from_numpy(interpolate(start=z1_hidden, end=z2_hidden, steps=args.steps)).float())

    if args.rnn_type == "lstm":
        z1_cell_state = z1['z_cell_state'].cpu()[0].squeeze()
        z2_cell_state = z2['z_cell_state'].cpu()[0].squeeze()

        # print(z1_cell_state.shape)

        z_cell_states = \
            to_var(torch.from_numpy(interpolate(start=z1_cell_state, end=z2_cell_state, steps=args.steps)).float())

        samples, _ = model.inference(z=z_hidden, z_cell_state=z_cell_states)
    else:
        samples, _ = model.inference(z=z_hidden, z_cell_state=None)
    # print('-------INTERPOLATION-------')

    interpolated_sentences = idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])
    # For each sentence, get the perplexity and show it
    # for sentence in interpolated_sentences:
        # print(sentence + "\t\t" + str(sl.get_perplexity(sentence)))
        # print(sentence)

    return interpolated_sentences
Example #10
def main(args):

    with open(args.data_dir + '/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceJMVAE(vocab_size=len(w2i),
                          sos_idx=w2i['<sos>'],
                          eos_idx=w2i['<eos>'],
                          pad_idx=w2i['<pad>'],
                          max_sequence_length=args.max_sequence_length,
                          embedding_size=args.embedding_size,
                          rnn_type=args.rnn_type,
                          hidden_size=args.hidden_size,
                          word_dropout=args.word_dropout,
                          latent_size=args.latent_size,
                          num_layers=args.num_layers,
                          bidirectional=args.bidirectional,
                          label_sequence_len=args.label_sequence_len)

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    print('summary')
    model.load_state_dict(
        torch.load(args.load_checkpoint,
                   map_location=lambda storage, loc: storage))
    #print("Model loaded from s%"%(args.load_checkpoint))

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()

    #Create decoding dict for attributes
    predicate_dict = defaultdict(set)

    #Load in attribute types from trainset
    df = pd.read_csv('./e2e-dataset/trainset.csv', delimiter=',')
    tuples = [tuple(x) for x in df.values]
    #Parse all  the attribute inputs
    for t in tuples:
        for r in t[0].split(','):
            r_ind1 = r.index('[')
            r_ind2 = r.index(']')
            rel = r[0:r_ind1].strip()
            rel_val = r[r_ind1 + 1:r_ind2]
            predicate_dict[rel].add(rel_val)

    #Sort attribute inputs for consistency across runs
    od = OrderedDict(sorted(predicate_dict.items()))
    for key in od.keys():
        od[key] = sorted(od[key])
    predicate_dict = od
    rel_lens = [len(predicate_dict[p]) for p in predicate_dict.keys()]
    print('lrl', len(rel_lens))
    rel_list = list(predicate_dict.keys())
    rel_val_list = list(predicate_dict.values())

    X = np.zeros((len(tuples), sum(rel_lens)), dtype=np.int)
    #Generate input in matrix form
    int_to_rel = defaultdict()
    for i, tup in enumerate(tuples):
        for relation in tup[0].split(','):
            rel_name = relation[0:relation.index('[')].strip()
            rel_value = relation[relation.index('[') + 1:-1].strip()
            name_ind = rel_list.index(rel_name)
            value_ind = list(predicate_dict[rel_name]).index(rel_value)
            j = sum(rel_lens[0:name_ind]) + value_ind
            #print(relation,j)
            int_to_rel[j] = relation
            X[i, j] = 1.

    ii = 22
    print(tuples[ii])
    print(X[ii])
    indices_att = [int_to_rel[i] for i, x in enumerate(X[ii]) if x == 1]
    print(indices_att)
    y_datasets = OrderedDict()
    print('----------DECODINGS----------')

    w_datasets, y_datasets = load_e2e(False, args.max_sequence_length, 1)
    print('a')
    batch_size = 10
    data_loader = DataLoader(
        dataset=w_datasets['test'],  #y_datasets[split],
        batch_size=batch_size,
        shuffle=False,
        num_workers=cpu_count(),
        pin_memory=torch.cuda.is_available())

    batch_num = random.randint(0, len(data_loader) - 1)
    max_max_word = 0
    false_neg_list = {'label': [], 'joint': [], 'sentence': []}
    false_pos_list = {'label': [], 'joint': [], 'sentence': []}
    acc_list = {'label': [], 'joint': [], 'sentence': []}
    loss_list = {'label': [], 'joint': [], 'sentence': []}
    perfect = {'label': [], 'joint': [], 'sentence': []}

    NLL = torch.nn.NLLLoss(size_average=False,
                           ignore_index=w_datasets['train'].pad_idx)
    BCE = torch.nn.BCELoss(size_average=False)

    def loss_fn_plus(logp, logp2, target, target2, length, mean, logv, mean_w,
                     logv_w, mean_y, logv_y, anneal_function, step, k, x0):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))
        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        #Cross entropy loss
        BCE_loss = BCE(logp2, target2)
        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

        return NLL_loss, BCE_loss, KL_loss

    for iteration, batch in enumerate(data_loader):
        batch_size = len(batch['input'])
        #if not iteration==batch_num:
        #    continue

        #Make sure all word indices are within the range of the trained model
        if torch.max(
                batch['input']
        ) > 1683:  #[33,36,72,73,83,129,157,158,165,177,181,201,274,352,459]:
            print(iteration)
            continue
        batch['labels'] = batch['labels'].float()
        #for k, v in batch.items():
        #    if torch.is_tensor(v):
        #        batch[k] = to_var(v)
        sorted_lengths, sorted_idx = torch.sort(batch['length'],
                                                descending=True)
        input_sequence = batch['input'][sorted_idx]
        #print(input_sequence)
        #import pdb; pdb.set_trace()
        if torch.max(input_sequence) > max_max_word:
            max_max_word = torch.max(input_sequence)
        print(max_max_word)
        #print(iteration,torch.max(input_sequence))
        input_embedding = model.embedding(input_sequence)
        label_sequence = batch['labels'][sorted_idx]
        param_y = model.encode_y(label_sequence)
        param_joint = model.encode_joint(input_embedding, label_sequence,
                                         sorted_lengths)
        param_w = model.encode_w(input_embedding, sorted_lengths)

        _, reversed_idx = torch.sort(sorted_idx)

        params = [(param_y, 'label'), (param_joint, 'joint'),
                  (param_w, 'sentence')]
        print_stuff = False
        if iteration == batch_num:
            print_stuff = True
        for param, name in params:
            if print_stuff:
                print('----------Reconstructions from ' + name +
                      ' data----------')
            mu_i, sig_i = param
            z_w = model.sample_z(batch_size, mu_i, sig_i)
            #_, y_decoded = model.decode_joint(z_w,input_embedding,sorted_lengths,sorted_idx)
            y_decoded = model.decode_to_y(z_w, sorted_lengths, sorted_idx)
            samples_w, z_w2 = model.inference(n=args.num_samples, z=z_w)
            inp_samp = idx2word(input_sequence, i2w=i2w, pad_idx=w2i['<pad>'])
            dec_samp = idx2word(samples_w, i2w=i2w, pad_idx=w2i['<pad>'])
            for iter in zip(inp_samp, dec_samp, label_sequence, y_decoded):
                if print_stuff:
                    print('True', iter[0], '\n', 'Pred', iter[1])

                true_att = [
                    int_to_rel[i] for i, x in enumerate(iter[2].round())
                    if x == 1.
                ]
                pred_att = [
                    int_to_rel[i] for i, x in enumerate(iter[3].round())
                    if x == 1.
                ]
                if print_stuff:
                    print('True', true_att, '\nPred', pred_att)
                mistakes = 0
                false_pos = 0
                false_neg = 0
                for t, p in zip(iter[2].round(), iter[3].round()):
                    if t != p:
                        mistakes += 1
                        if p == 0.:
                            false_neg += 1
                        elif p == 1.:
                            false_pos += 1
                if print_stuff:
                    print('Mistakes:', mistakes)
                    print('Accuracy:',
                          (len(iter[2]) - mistakes) / len(iter[2]))
                    print('False pos:', false_pos, '\tFalse neg:', false_neg,
                          '\n')
                false_neg_list[name].append(false_neg)
                false_pos_list[name].append(false_pos)
                acc_list[name].append((len(iter[2]) - mistakes) / len(iter[2]))

            #logp, logp2, mean, logv, z, mean_w, logv_w, mean_y, logv_y = model(batch['input'],batch['labels'], batch['length'])
            loss_current = [
                0, 0, 0
            ]  #loss_fn_plus(logp, logp2, batch['target'], batch['labels'],
            #batch['length'], mean, logv, mean_w, logv_w, mean_y, logv_y, 'logistic', 1000, 0.0025, 2500)

            loss_list[name].append(loss_current)
        print_stuff = False

        #NLL_loss, BCE_loss, KL_loss, KL_loss_w, KL_loss_y, KL_weight
        #print('Attributes')
    for _, name in params:
        print(name + ':')
        print('mmw', max_max_word)
        print('avg false neg',
              sum(false_neg_list[name]) / len(false_neg_list[name]))
        print('avg false pos',
              sum(false_pos_list[name]) / len(false_pos_list[name]))
        print('avg accuracy', sum(acc_list[name]) / len(acc_list[name]))
        print('avg NLL',
              sum([x[0] for x in loss_list[name]]) / len(loss_list[name]))
        print('avg BCE',
              sum([x[1] for x in loss_list[name]]) / len(loss_list[name]))
        print('avg KL joint div',
              sum([x[2] for x in loss_list[name]]) / len(loss_list[name]))

    samples, z = model.inference(n=args.num_samples)
    print()
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(
        torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    #Make uncoupled label decoder
    #
    #
    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
Example #11
def main(checkpoint_fname, args):

    #=============================================================================#
    # Load model
    #=============================================================================#

    if not os.path.exists(checkpoint_fname):
        raise FileNotFoundError(checkpoint_fname)

    model = torch.load(checkpoint_fname)

    print("Model loaded from %s" % (checkpoint_fname))

    if torch.cuda.is_available():
        device = torch.device('cuda')
        torch.set_default_tensor_type('torch.cuda.FloatTensor')
    else:
        device = torch.device('cpu')

    model = model.to(device)

    # update max_sequence length
    if args.max_sequence_length > 0:
        model.max_sequence_length = args.max_sequence_length
    else:
        # else use model training value
        args.max_sequence_length = model.max_sequence_length

    if args.sample_mode:
        model.sample_mode = args.sample_mode

    compute_temperature = False
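    # temperature > 0: use the value as given; == 0: fall back to a heuristic
    # based on the latent size; < 0: estimate it from encoded data during testing.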
    if args.temperature > 0.0:
        model.temperature = args.temperature
    elif args.temperature == 0.0:
        model.temperature = 1.0 / model.latent_size * 0.5
    elif args.temperature < 0.0:
        compute_temperature = True

    model.eval()

    base_fname = "{base_fname}-{split}-{seed}-marginal{marginal}-mcmc{mcmc}".format(
        base_fname=os.path.splitext(checkpoint_fname)[0],
        split=args.split,
        seed=args.seed,
        marginal=int(model.marginal),
        mcmc=args.mcmc,
    )

    data_fname = base_fname + ".csv"
    log_fname = base_fname + ".txt"

    print("log_fname = {log_fname}".format(log_fname=log_fname))

    #=============================================================================#
    # Log results
    #=============================================================================#

    log_fh = open(log_fname, "w")

    def print_log(*args, **kwargs):
        """
        Print to screen and log file.
        """

        print(*args, **kwargs, file=log_fh)
        return print(*args, **kwargs)

    #=============================================================================#
    # Load data
    #=============================================================================#
    print_log('----------INFO----------\n')
    print_log("checkpoint_fname = {checkpoint_fname}".format(
        checkpoint_fname=checkpoint_fname))
    print_log("\nargs = {args}\n".format(args=args))
    print_log("\nmodel.args = {args}\n".format(args=model.args))
    params_num = len(torch.nn.utils.parameters_to_vector(model.parameters()))
    print_log("\n parameters number = {params_num_h} [{params_num}]".format(
        params_num_h=aux.millify(params_num),
        params_num=params_num,
    ))

    dataset_name = model.args.dataset.lower()
    split = args.split

    dataset = aux.load_dataset(
        dataset_name=dataset_name,
        split=split,
        args=args,
        enforce_sos=False,
    )

    vocab = aux.load_vocab(
        dataset_name=dataset_name,
        args=args,
    )
    w2i, i2w = vocab['w2i'], vocab['i2w']

    if (args.batch_size <= 0):
        args.batch_size = args.num_samples

    args.batch_size = min(len(dataset), args.batch_size)
    args.num_samples = min(len(dataset), args.num_samples)

    # collect all stats
    stats = {}

    #=============================================================================#
    # Evaluate model
    #=============================================================================#

    if args.test:
        print("Testing model")
        print("data_fname = {data_fname}".format(data_fname=data_fname))

        print_log('----------EVALUATION----------\n')

        tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
        ) else torch.Tensor

        data_loader = DataLoader(
            dataset=dataset,
            batch_size=args.batch_size,
            shuffle=True,
            drop_last=True,
        )

        # entropy
        nll = torch.nn.NLLLoss(reduction="none", ignore_index=dataset.pad_idx)

        def loss_fn(
            logp,
            target,
            length,
            mean,
            logv,
            z,
            nll=nll,
            N=1,
            eos_idx=dataset.eos_idx,
            pad_idx=dataset.pad_idx,
        ):
            batch_size = target.size(0)

            # do not count probability over the <eos> token
            eos_I = (target == eos_idx)
            target[eos_I] = pad_idx

            # cut-off unnecessary padding from target, and flatten
            target = target[:, :torch.max(length).item()].contiguous().view(-1)

            # dataset size
            N = torch.tensor(N).type_as(logp)

            log_p_x_given_z = logp.view(-1, logp.size(2))

            q_z_given_x = torch.distributions.Normal(
                loc=mean,
                scale=torch.exp(0.5 * logv),
            )
            log_q_z_given_x = q_z_given_x.log_prob(z).sum(-1)

            # conditional entropy
            H_p_x_given_z = nll(log_p_x_given_z, target).view(
                (batch_size, -1)).sum(-1)

            if model.args.mim:
                log_p_z = log_q_z_given_x - torch.log(N)
            else:
                p_z = model.get_prior()
                log_p_z = p_z.log_prob(z)

            if len(log_p_z.shape) > 1:
                log_p_z = log_p_z.sum(-1)

            # marginal entropy
            CE_q_p_z = (-log_p_z)
            H_q_z_given_x = (-log_q_z_given_x)
            # KL divergence between q(z|x) and p(z)
            KL_q_p = CE_q_p_z - H_q_z_given_x

            # NLL upper bound
            if model.args.mim:
                # MELBO
                H_p_x = H_p_x_given_z + torch.log(N)
            else:
                # ELBO
                H_p_x = H_p_x_given_z + KL_q_p

            return dict(
                H_q_z_given_x=H_q_z_given_x,
                CE_q_p_z=CE_q_p_z,
                H_p_x_given_z=H_p_x_given_z,
                H_p_x=H_p_x,
                KL_q_p=KL_q_p,
            )

        def test_model(model=model,
                       data_loader=data_loader,
                       desc="",
                       plot_dist=False,
                       base_fname=base_fname,
                       compute_temperature=False):
            """
            Compute various quantities for a model
            """
            word_count = 0
            N = len(data_loader.dataset)
            B = N // args.batch_size
            tracker = defaultdict(tensor)

            all_z = []
            for iteration, batch in tqdm(
                    enumerate(data_loader),
                    desc=desc,
                    total=B,
            ):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                all_z.append(z.detach().cpu().numpy())
                # Model evaluation
                loss_dict = loss_fn(
                    logp,
                    batch['target'],
                    batch['length'],
                    mean,
                    logv,
                    z,
                    N=N,
                )

                # aggregate values
                for k, v in loss_dict.items():
                    tracker[k] = torch.cat((tracker[k], v.detach().data))

                # subtract <eos> token from word count
                word_count = (
                    word_count + batch['length'].sum().type_as(z) -
                    (batch['target'] == dataset.eos_idx).sum().type_as(z))

                # BLEU
                if args.test_bleu:
                    recon, _, recon_l = model.decode(z=model.encode(
                        batch['input'], batch['length']), )
                    batch_bleu = []
                    for d, dl, r, rl in zip(batch['target'], batch['length'],
                                            recon, recon_l):
                        cur_bleu = bleu_score.sentence_bleu(
                            references=[d[:dl].tolist()],
                            hypothesis=r[:rl].tolist(),
                            weights=(1.0, ),
                        )
                        batch_bleu.append(cur_bleu)

                    tracker["BLEU"] = torch.cat(
                        (tracker["BLEU"], torch.tensor(batch_bleu)))
            if model.latent_size > 300:
                H_p_z = -1.0
            else:
                H_p_z = ee.entropy(np.concatenate(all_z[:1000], axis=0),
                                   base=np.e)
            H_normal_z = float(model.latent_size) / 2 * (1 + np.log(2 * np.pi))

            H_p_x_given_z_acc = tracker["H_p_x_given_z"].sum()
            H_p_x_acc = tracker["H_p_x"].sum()

            ppl_x_given_z = torch.exp(H_p_x_given_z_acc / word_count)
            ppl_x = torch.exp(H_p_x_acc / word_count)

            tracker["H_p_z"] = torch.tensor(H_p_z)
            tracker["H_normal_z"] = torch.tensor(H_normal_z)
            tracker["ppl_x_given_z"] = ppl_x_given_z
            tracker["ppl_x"] = ppl_x

            if plot_dist:
                print_log(
                    "Saving images and data to base_fname = {base_fname}*".
                    format(base_fname=base_fname, ))
                # Plot distribution of values
                for k, v in tracker.items():
                    if v.numel() > 1:
                        v = v.cpu().detach().numpy()
                        fig = plt.figure()
                        plt.hist(v, density=True, bins=50)
                        # mean
                        plt.axvline(np.mean(v), lw=3, ls="--", c="k")
                        if k == "H_p_x":
                            # sample entropy
                            plt.axvline(np.log(len(v)), lw=3, ls="-", c="k")

                        fig.savefig(base_fname + "-" + k + ".png",
                                    bbox_inches='tight')
                        plt.close(fig)
                        # save data
                        np.save(base_fname + "-" + k + ".npy", v)

            if compute_temperature:
                model.temperature = np.std(all_z)

            return {
                k: v.detach().mean().unsqueeze(0)
                for k, v in tracker.items()
            }

        tracker = defaultdict(tensor)
        for epoch in range(args.test_epochs):
            cur_tracker = test_model(
                model=model,
                data_loader=data_loader,
                desc="Batch [{:d} / {:d}]".format(epoch + 1, args.test_epochs),
                plot_dist=(epoch == 0),
                compute_temperature=(compute_temperature and (epoch == 0)),
            )

            for k, v in cur_tracker.items():
                tracker[k] = torch.cat([tracker[k], cur_tracker[k]])

        for k, v in tracker.items():
            v_mean = v.detach().cpu().mean().numpy()
            v_std = v.detach().cpu().std().numpy()
            print_log("{k} = {v_mean} +/- {v_std}".format(
                k=k,
                v_mean=v_mean,
                v_std=v_std,
            ))
            stats[k] = [k, v_mean, v_std]

        print_log("")

    #=============================================================================#
    # Save stats
    #=============================================================================#
    if len(stats):
        with open(data_fname, 'w') as fh:
            writer = csv.writer(fh)
            for k, row in stats.items():
                if isinstance(row, list):
                    writer.writerow(row)
                else:
                    writer.writerow([k, row])

    #=============================================================================#
    # Sample
    #=============================================================================#
    if args.test_sample:
        print_log("\n model.temperature = {temperature:e}\n".format(
            temperature=model.temperature))

        aux.reset_seed(args.seed)

        batches_in_samples = max(1, args.num_samples // args.batch_size)
        all_samples = []
        # all_z = []
        for b in range(batches_in_samples):
            samples, z, length = model.sample(
                n=args.batch_size,
                z=None,
                mcmc=args.mcmc,
            )

            all_samples.append(samples.detach())

        samples = torch.cat(all_samples, dim=0)[:args.num_samples]

        print_log('----------SAMPLES----------\n')
        for s in idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']):
            print_log("SAMP: {}\n".format(s))

    #=============================================================================#
    # Reconstruction
    #=============================================================================#
    aux.reset_seed(args.seed)

    # Reconstruct starting from <sos>
    dataset.enforce_sos = True

    data_loader = DataLoader(
        dataset=dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )

    # collect non-empty sentences
    data_iter = iter(data_loader)
    samples = {'input': [], 'target': [], 'length': []}

    for data in data_iter:
        for k, v in data.items():
            if torch.is_tensor(v):
                data[k] = to_var(v)

        for i in range(args.batch_size):
            if (data["length"][i] >= args.min_sample_length) and (
                    data["length"][i] <= args.max_sample_length):
                if args.no_unk_sample:
                    if ((data["input"][i] == dataset.unk_idx).sum() >= 1):
                        continue

                for k, v in data.items():
                    if k in samples:
                        samples[k].append(v[i])

            if len(samples["length"]) >= args.num_samples:
                break

        if len(samples["length"]) >= args.num_samples:
            break

    for k, v in samples.items():
        samples[k] = torch.stack(v)[:args.num_samples]

    z, mean, std = model.encode(samples['input'],
                                samples['length'],
                                return_mean=True,
                                return_std=True)
    z = z.detach()
    mean = mean.detach()
    mean_recon, _ = model.inference(z=mean)
    mean_recon = mean_recon.detach()
    z_recon, _ = model.inference(z=z)
    z_recon = z_recon.detach()
    pert, _ = model.inference(z=z + torch.randn_like(z) * args.pert * std)
    pert = pert.detach()

    print_log('----------RECONSTRUCTION----------\n')
    for i, (d, mr, zr, p) in enumerate(
            zip(
                idx2word(samples["input"], i2w=i2w, pad_idx=w2i['<pad>']),
                idx2word(mean_recon, i2w=i2w, pad_idx=w2i['<pad>']),
                idx2word(z_recon, i2w=i2w, pad_idx=w2i['<pad>']),
                idx2word(pert, i2w=i2w, pad_idx=w2i['<pad>']),
            )):

        print_log("DATA: {}".format(d))
        print_log("MEAN RECON: {}".format(mr))
        print_log("Z RECON: {}".format(zr))
        print_log("Z PERT: {}".format(p))

        print_log("\n")

    #=============================================================================#
    # Interpolation
    #=============================================================================#
    if args.test_interp:
        args.num_samples = min(args.num_samples, z.shape[0])
        for i in range(args.num_samples - 1):
            z1 = z[i].cpu().numpy()
            z2 = z[i + 1].cpu().numpy()
            z_L2 = np.sqrt(np.sum((z1 - z2)**2))
            z_interp = to_var(
                torch.from_numpy(interpolate(start=z1, end=z2,
                                             steps=8)).float())
            samples_interp, _ = model.inference(z=z_interp)
            sample0 = samples["input"][i:i + 1]
            sample1 = samples["input"][i + 1:i + 2]

            print_log(
                '-------INTERPOLATION [ z L2 = {zL2:.3f} ] {src} -> {dst} -------\n\n[ {sample} ]\n'
                .format(
                    src=i,
                    dst=i + 1,
                    sample=idx2word(sample0, i2w=i2w, pad_idx=w2i['<pad>'])[0],
                    zL2=z_L2,
                ))
            print_log(*idx2word(samples_interp, i2w=i2w, pad_idx=w2i['<pad>']),
                      sep='\n\n')
            print_log('\n[ {sample} ]\n'.format(sample=idx2word(
                sample1, i2w=i2w, pad_idx=w2i['<pad>'])[0], ))
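Example #12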
def main(args):
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    total_steps = (len(datasets["train"]) // args.batch_size) * args.epochs
    print("Train dataset size", total_steps)

    def kl_anneal_function(anneal_function, step):
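        # 'identity' keeps the KL weight fixed at 1; 'linear' ramps it from 0
        # to 1 over the warmup steps (or over all training steps if no warmup).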
        if anneal_function == 'identity':
            return 1
        if anneal_function == 'linear':
            if args.warmup is None:
                return 1 - (total_steps - step) / total_steps
            else:
                warmup_steps = (total_steps / args.epochs) * args.warmup
                return 1 - (warmup_steps - step
                            ) / warmup_steps if step < warmup_steps else 1.0

    ReconLoss = torch.nn.NLLLoss(size_average=False,
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['negELBO'] = torch.cat(
                    (tracker['negELBO'], loss.data.unsqueeze(0)))

                if args.tensorboard_logging:
                    neg_elbo = (recon_loss + KL_loss) / batch_size
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      neg_elbo.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.data[0], recon_loss.data[0] / batch_size,
                           KL_loss.data[0] / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs,
                         torch.mean(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  torch.mean(tracker['negELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

    if args.num_samples:
        torch.cuda.empty_cache()
        model.eval()
        with torch.no_grad():
            print(f"Generating {args.num_samples} samples")
            generations, _ = model.inference(n=args.num_samples)
            vocab = datasets["train"].i2w

            print(
                "Sampled latent codes from z ~ N(0, I), generated sentences:")
            for i, generation in enumerate(generations, start=1):
                sentence = [vocab[str(word.item())] for word in generation]
                print(f"{i}:", " ".join(sentence))
Example #13
File: train.py Project: kaletap/Bert-VAE
def main(args):
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    RANDOM_SEED = 42

    dataset = load_dataset("yelp_polarity", split="train")
    TRAIN_SIZE = len(dataset) - 2_000
    VALID_SIZE = 1_000
    TEST_SIZE = 1_000

    train_test_split = dataset.train_test_split(train_size=TRAIN_SIZE,
                                                seed=RANDOM_SEED)
    train_dataset = train_test_split["train"]
    test_val_dataset = train_test_split["test"].train_test_split(
        train_size=VALID_SIZE, test_size=TEST_SIZE, seed=RANDOM_SEED)
    val_dataset, test_dataset = test_val_dataset["train"], test_val_dataset[
        "test"]

    tokenizer = AutoTokenizer.from_pretrained(args.model_name, use_fast=True)
    datasets = OrderedDict()
    datasets['train'] = TextDataset(train_dataset, tokenizer,
                                    args.max_sequence_length,
                                    not args.disable_sent_tokenize)
    datasets['valid'] = TextDataset(val_dataset, tokenizer,
                                    args.max_sequence_length,
                                    not args.disable_sent_tokenize)
    if args.test:
        datasets['test'] = TextDataset(test_dataset, tokenizer,
                                       args.max_sequence_length,
                                       not args.disable_sent_tokenize)

    print(
        f"Loading {args.model_name} model. Setting {args.trainable_layers} trainable layers."
    )
    encoder = AutoModel.from_pretrained(args.model_name, return_dict=True)
    if not args.train_embeddings:
        for p in encoder.embeddings.parameters():
            p.requires_grad = False
    encoder_layers = encoder.encoder.layer
    if args.trainable_layers > len(encoder_layers):
        warnings.warn(
            f"You are asking to train {args.trainable_layers} layers, but this model has only {len(encoder_layers)}"
        )
    for layer in range(len(encoder_layers) - args.trainable_layers):
        for p in encoder_layers[layer].parameters():
            p.requires_grad = False
    params = dict(vocab_size=datasets['train'].vocab_size,
                  embedding_size=args.embedding_size,
                  rnn_type=args.rnn_type,
                  hidden_size=args.hidden_size,
                  word_dropout=args.word_dropout,
                  embedding_dropout=args.embedding_dropout,
                  latent_size=args.latent_size,
                  num_layers=args.num_layers,
                  bidirectional=args.bidirectional,
                  max_sequence_length=args.max_sequence_length)
    model = SentenceVAE(encoder=encoder, tokenizer=tokenizer, **params)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
        json.dump(params, f, indent=4)
    with open(os.path.join(save_model_path, 'train_args.json'), 'w') as f:
        json.dump(vars(args), f, indent=4)

    def kl_anneal_function(anneal_function, step, k, x0):
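        # hold the initial KL weight until step x0, then anneal it towards 1
        # with either a logistic or a linear schedule.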
        if step <= x0:
            return args.initial_kl_weight
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0 - 2500))))
        elif anneal_function == 'linear':
            return min(1, step / x0)

    NLL = torch.nn.NLLLoss(ignore_index=datasets['train'].pad_idx,
                           reduction='sum')

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k,
                x0):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight

    params = [{
        'params': model.encoder.parameters(),
        'lr': args.encoder_learning_rate
    }, {
        'params': [
            *model.decoder_rnn.parameters(), *model.hidden2mean.parameters(),
            *model.hidden2logv.parameters(), *model.latent2hidden.parameters(),
            *model.outputs2vocab.parameters()
        ]
    }]
    optimizer = torch.optim.Adam(params,
                                 lr=args.learning_rate,
                                 weight_decay=args.weight_decay)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=(split == 'train'),
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available(),
                                     collate_fn=DataCollator(tokenizer))

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'],
                                            batch['attention_mask'],
                                            batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                                                       batch['length'], mean,
                                                       logv,
                                                       args.anneal_function,
                                                       step, args.k, args.x0)

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['ELBO'] = torch.cat(
                    (tracker['ELBO'], loss.data.view(1, -1)), dim=0)

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.item(),
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(),
                                      NLL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(),
                                      KL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    print(
                        "%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), NLL_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].tolist(), tokenizer=tokenizer)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs, tracker['ELBO'].mean()))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(),
                                  torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences, the encoded latent space and generated sequences
            if split == 'valid':
                samples, _ = model.inference(z=tracker['z'])
                generated_sents = idx2word(samples.tolist(), tokenizer)
                sents = [{
                    'original': target,
                    'generated': generated
                } for target, generated in zip(tracker['target_sents'],
                                               generated_sents)]
                dump = {'sentences': sents, 'z': tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file, indent=3)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)
Example #14
def main(args):

    #create dir name
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())
    ts = ts.replace(':', '-')

    #prepare dataset
    splits = ['train', 'valid'] + (['test'] if args.test else [])

    #create dataset object
    datasets = OrderedDict()

    # create test and train split in data, also preprocess
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    #get training params
    params = dict(vocab_size=datasets['train'].vocab_size,
                  sos_idx=datasets['train'].sos_idx,
                  eos_idx=datasets['train'].eos_idx,
                  pad_idx=datasets['train'].pad_idx,
                  unk_idx=datasets['train'].unk_idx,
                  max_sequence_length=args.max_sequence_length,
                  embedding_size=args.embedding_size,
                  rnn_type=args.rnn_type,
                  hidden_size=args.hidden_size,
                  word_dropout=args.word_dropout,
                  embedding_dropout=args.embedding_dropout,
                  latent_size=args.latent_size,
                  num_layers=args.num_layers,
                  bidirectional=args.bidirectional)

    #init model object
    model = SentenceVAE(**params)

    if torch.cuda.is_available():
        model = model.cuda()

    #logging
    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    # make dir
    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    #write params to json and save
    with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
        json.dump(params, f, indent=4)

    # define the function that returns the annealing (disentangling) weight applied to the KL loss at each training step
    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        elif anneal_function == 'linear':
            return min(1, step / x0)

    # define the NLL loss used to measure the reconstruction accuracy of the decoder
    NLL = torch.nn.NLLLoss(ignore_index=datasets['train'].pad_idx,
                           reduction='sum')

    # this function computes the two loss terms and the KL loss weight
    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k,
                x0):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor

    step = 0

    for epoch in range(args.epochs):

        #do train and then test
        for split in splits:

            #create dataloader
            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            #tracker used to track the loss
            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            #start batch wise training/testing
            for iteration, batch in enumerate(data_loader):

                #get batch size
                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                                                       batch['length'], mean,
                                                       logv,
                                                       args.anneal_function,
                                                       step, args.k, args.x0)

                # final loss calculation
                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()  #flush grads
                    loss.backward()  #run bp
                    optimizer.step()  #run gd
                    step += 1

                # bookkeeping
                tracker['ELBO'] = torch.cat(
                    (tracker['ELBO'], loss.data.view(1, -1)), dim=0)

                #logging of losses
                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.item(),
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(),
                                      NLL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(),
                                      KL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                # periodically print training progress
                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    print(
                        "%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), NLL_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs, tracker['ELBO'].mean()))

            #more logging
            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(),
                                  torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)
Example #15
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid']

    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filename=os.path.join(args.logdir,
                              experiment_name(args, ts) + ".log"))
    logger = logging.getLogger(__name__)

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    total_step = int(args.epochs * 42000.0 / args.batch_size)
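    # The hard-coded 42000.0 above roughly matches the number of sentences in the PTB
    # training split, so total_step approximates the number of optimizer updates over all
    # epochs (an estimate used only by the annealing schedules below).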

    def kl_anneal_function(anneal_function, step):
        if anneal_function == 'half':
            return 0.5
        if anneal_function == 'identity':
            return 1
        if anneal_function == 'double':
            return 2
        if anneal_function == 'quadra':
            return 4

        if anneal_function == 'sigmoid':
            return 1 / (1 + np.exp((0.5 * total_step - step) / 200))

        if anneal_function == 'monotonic':
            beta = step * 4 / total_step
            if beta > 1:
                beta = 1.0
            return beta

        if anneal_function == 'cyclical':
            t = total_step / 4
            beta = 4 * (step % t) / t
            if beta > 1:
                beta = 1.0
            return beta
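    # Rough shapes of the step-dependent schedules above (illustrative, e.g. total_step=10000):
    #   'monotonic': beta ramps linearly from 0 to 1 over the first quarter of training
    #                (steps 0..2500) and stays at 1 afterwards.
    #   'cyclical':  training is split into 4 cycles of t=2500 steps; within each cycle beta
    #                ramps 0 -> 1 over the first 625 steps and is then clipped at 1
    #                (cyclical KL annealing).
    #   'sigmoid':   a logistic curve centred at total_step/2 with a width of a few hundred
    #                steps.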

    ReconLoss = torch.nn.NLLLoss(reduction='sum',
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    train_loss = []
    test_loss = []
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(list)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                # tracker['negELBO'] = torch.cat((tracker['negELBO'], loss.data))
                tracker["negELBO"].append(loss.item())
                tracker["recon_loss"].append(recon_loss.item() / batch_size)
                tracker["KL_Loss"].append(KL_loss.item() / batch_size)
                tracker["KL_Weight"].append(KL_weight)

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.item(),
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "\tStep\t%s\t%04d\t%i\t%9.4f\t%9.4f\t%9.4f\t%6.3f" %
                        (split.upper(), iteration, len(data_loader) - 1,
                         loss.item(), recon_loss.item() / batch_size,
                         KL_loss.item() / batch_size, KL_weight))
                    print(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), recon_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'].append(z.data.tolist())

            logger.info(
                "\tEpoch\t%s\t%02d\t%i\t%9.4f\t%9.4f\t%9.4f\t%6.3f" %
                (split.upper(), epoch, args.epochs,
                 sum(tracker['negELBO']) / len(tracker['negELBO']),
                 1.0 * sum(tracker['recon_loss']) / len(tracker['recon_loss']),
                 1.0 * sum(tracker['KL_Loss']) / len(tracker['KL_Loss']),
                 1.0 * sum(tracker['KL_Weight']) / len(tracker['KL_Weight'])))
            print("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs,
                   sum(tracker['negELBO']) / len(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar(
                    "%s-Epoch/NegELBO" % split.upper(),
                    1.0 * sum(tracker['negELBO']) / len(tracker['negELBO']),
                    epoch)
                writer.add_scalar(
                    "%s-Epoch/recon_loss" % split.upper(), 1.0 *
                    sum(tracker['recon_loss']) / len(tracker['recon_loss']),
                    epoch)
                writer.add_scalar(
                    "%s-Epoch/KL_Loss" % split.upper(),
                    1.0 * sum(tracker['KL_Loss']) / len(tracker['KL_Loss']),
                    epoch)
                writer.add_scalar(
                    "%s-Epoch/KL_Weight" % split.upper(), 1.0 *
                    sum(tracker['KL_Weight']) / len(tracker['KL_Weight']),
                    epoch)

            if split == 'train':
                train_loss.append(1.0 * sum(tracker['negELBO']) /
                                  len(tracker['negELBO']))
            else:
                test_loss.append(1.0 * sum(tracker['negELBO']) /
                                 len(tracker['negELBO']))
            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z']
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)

    sns.set(style="whitegrid")
    df = pd.DataFrame()
    df["train"] = train_loss
    df["test"] = test_loss
    ax = sns.lineplot(data=df, legend=False)
    ax.set(xlabel='Epoch', ylabel='Loss')
    plt.legend(title='Split', loc='upper right', labels=['Train', 'Test'])
    plt.savefig(os.path.join(args.logdir,
                             experiment_name(args, ts) + ".png"),
                transparent=True,
                dpi=300)
Example #16
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    log_file = open("res.txt", "a")
    log_file.write(expierment_name(args, ts))
    log_file.write("\n")
    graph_file = open("elbo-graph.txt", "a")
    graph_file.write(expierment_name(args, ts))
    graph_file.write("\n")

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        elif anneal_function == 'linear':
            return min(1, step / x0)
        elif anneal_function == "softplus":
            return min(1, np.log(1 + np.exp(k * step)))
        elif anneal_function == "no":
            return 1
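    # Behaviour of the two extra branches above (illustrative, assuming e.g. k=0.0025):
    #   'softplus': min(1, log(1 + exp(k*step))) starts near log(2) ~ 0.69 at step 0 and
    #               saturates at 1 once k*step exceeds ~0.54 (around step 217 here).
    #   'no':       a constant weight of 1, i.e. no annealing at all.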

    NLL = torch.nn.NLLLoss(size_average=False,
                           ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k,
                x0):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    val_lowest_elbo = 5000
    val_accu_epoch = 0
    val_min_epoch = 0
    split_elbo = {"train": [], "valid": []}
    if args.test:
        split_elbo["test"] = []
    split_loss = {"train": [], "valid": []}
    if args.test:
        split_loss["test"] = []

    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                                                       batch['length'], mean,
                                                       logv,
                                                       args.anneal_function,
                                                       step, args.k, args.x0)

                if split != 'train':
                    KL_weight = 1.0

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(),
                                      NLL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    print(
                        "%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.data[0], NLL_loss.data[0] / batch_size,
                           KL_loss.data[0] / batch_size, KL_weight))
                    split_loss[split].append([
                        loss.data[0], NLL_loss.data[0] / batch_size,
                        KL_loss.data[0] / batch_size
                    ])

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs,
                   torch.mean(tracker['ELBO'])))
            split_elbo[split].append([torch.mean(tracker["ELBO"])])

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(),
                                  torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)

            if split == 'valid':
                if torch.mean(tracker['ELBO']) < val_lowest_elbo:
                    val_lowest_elbo = torch.mean(tracker['ELBO'])
                    val_accu_epoch = 0
                    val_min_epoch = epoch
                else:
                    val_accu_epoch += 1
                    if val_accu_epoch >= 3:
                        if not args.test:
                            exp_str = ""
                            exp_str += "train_ELBO={}\n".format(
                                split_elbo["train"][val_min_epoch])
                            exp_str += "valid_ELBO={}\n".format(
                                split_elbo["valid"][val_min_epoch])
                            exp_str += "==========\n"
                            log_file.write(exp_str)
                            log_file.close()
                            print(exp_str)
                            graph_file.write("ELBO\n")
                            line = ""
                            for s in splits:
                                for i in split_loss[s]:
                                    line += "{},".format(i[0])
                                line += "\n"
                            graph_file.write(line)
                            graph_file.write("NLL\n")
                            line = ""
                            for s in splits:
                                for i in split_loss[s]:
                                    line += "{},".format(i[1])
                                line += "\n"
                            graph_file.write(line)
                            graph_file.write("KL\n")
                            line = ""
                            for s in splits:
                                for i in split_loss[s]:
                                    line += "{},".format(i[2])
                                line += "\n"
                            graph_file.write(line)
                            graph_file.close()
                            exit()
            elif split == 'test' and val_accu_epoch >= 3:
                exp_str = ""
                exp_str += "train_ELBO={}\n".format(
                    split_elbo["train"][val_min_epoch])
                exp_str += "valid_ELBO={}\n".format(
                    split_elbo["valid"][val_min_epoch])
                exp_str += "test_ELBO={}\n".format(
                    split_elbo["test"][val_min_epoch])
                exp_str += "==========\n"
                log_file.write(exp_str)
                log_file.close()
                print(exp_str)
                graph_file.write("ELBO\n")
                line = ""
                for s in splits:
                    for i in split_loss[s]:
                        line += "{},".format(i[0])
                    line += "\n"
                for s in splits:
                    for i in split_elbo[s]:
                        line += "{},".format(i[0])
                    line += "\n"
                graph_file.write(line)
                graph_file.write("NLL\n")
                line = ""
                for s in splits:
                    for i in split_loss[s]:
                        line += "{},".format(i[1])
                    line += "\n"
                graph_file.write(line)
                graph_file.write("KL\n")
                line = ""
                for s in splits:
                    for i in split_loss[s]:
                        line += "{},".format(i[2])
                    line += "\n"
                graph_file.write(line)
                graph_file.close()
                exit()

        if epoch == args.epochs - 1:
            exp_str = ""
            exp_str += "train_ELBO={}\n".format(
                split_elbo["train"][val_min_epoch])
            exp_str += "valid_ELBO={}\n".format(
                split_elbo["valid"][val_min_epoch])
            if args.test:
                exp_str += "test_ELBO={}\n".format(
                    split_elbo["test"][val_min_epoch])
            exp_str += "==========\n"
            log_file.write(exp_str)
            log_file.close()
            print(exp_str)
            graph_file.write("ELBO\n")
            line = ""
            for s in splits:
                for i in split_loss[s]:
                    line += "{},".format(i[0])
                line += "\n"
            graph_file.write(line)
            graph_file.write("NLL\n")
            line = ""
            for s in splits:
                for i in split_loss[s]:
                    line += "{},".format(i[1])
                line += "\n"
            graph_file.write(line)
            graph_file.write("KL\n")
            line = ""
            for s in splits:
                for i in split_loss[s]:
                    line += "{},".format(i[2])
                line += "\n"
            graph_file.write(line)
            graph_file.close()
            exit()
Example #17
def main(args):
    #print('start')
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid']

    if args.tensorboard_logging:
        print('Tensorboard logging on')

    w_datasets, y_datasets = load_e2e(args.create_data,
                                      args.max_sequence_length, args.min_occ)
    print('datasets loaded')
    print((y_datasets[splits[0]].shape[1]))
    label_sequence_len = y_datasets[splits[0]].shape[1]

    print('lsl')
    print(y_datasets['train'].shape)
    model = SentenceJMVAE(vocab_size=w_datasets['train'].vocab_size,
                          sos_idx=w_datasets['train'].sos_idx,
                          eos_idx=w_datasets['train'].eos_idx,
                          pad_idx=w_datasets['train'].pad_idx,
                          max_sequence_length=args.max_sequence_length,
                          embedding_size=args.embedding_size,
                          rnn_type=args.rnn_type,
                          hidden_size=args.hidden_size,
                          word_dropout=args.word_dropout,
                          latent_size=args.latent_size,
                          num_layers=args.num_layers,
                          label_sequence_len=label_sequence_len,
                          bidirectional=args.bidirectional)
    print('model created')
    if torch.cuda.is_available():
        model = model.cuda()

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join('./', args.logdir, 'JMVAE', expierment_name(args,
                                                                     ts)))
        writer.add_text("model_jmvae", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join('./', args.save_model_path, 'JMVAE', ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        elif anneal_function == 'linear':
            return min(1, step / x0)

    NLL = torch.nn.NLLLoss(size_average=False,
                           ignore_index=w_datasets['train'].pad_idx)
    BCE = torch.nn.BCELoss(size_average=False)

    def loss_fn_plus(logp, logp2, target, target2, length, mean, logv, mean_w,
                     logv_w, mean_y, logv_y, anneal_function, step, k, x0):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))
        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)
        NLL_w_avg = NLL_loss / torch.sum(length).float()
        #Cross entropy loss
        BCE_loss = BCE(logp2, target2)
        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())

        KL_loss_w = [
            0.5 * ((sigma0.exp() / sigma1.exp()).sum() + torch.sum(
                ((mu1 - mu0)**2) * (1 / torch.exp(sigma1))) -
                   (mu0.size(0)) + sigma1.sum() - sigma0.sum())
            for mu0, sigma0, mu1, sigma1 in zip(mean, logv, mean_w, logv_w)
        ]
        KL_loss_w = sum(KL_loss_w)  #/len(KL_loss_w)

        KL_loss_y = [
            0.5 * ((sigma0.exp() / sigma1.exp()).sum() + torch.sum(
                ((mu1 - mu0)**2) * (1 / torch.exp(sigma1))) -
                   (mu0.size(0)) + sigma1.sum() - sigma0.sum())
            for mu0, sigma0, mu1, sigma1 in zip(mean, logv, mean_y, logv_y)
        ]
        KL_loss_y = sum(KL_loss_y)  #/len(KL_loss_y)

        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, BCE_loss, KL_loss, KL_loss_w, KL_loss_y, KL_weight, NLL_w_avg
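    # The two list comprehensions above appear to compute, per sample, the closed-form KL
    # divergence between diagonal Gaussians, e.g. for the word encoder
    #   KL(N(mean, exp(logv)) || N(mean_w, exp(logv_w)))
    #     = 0.5 * sum( exp(logv - logv_w) + (mean_w - mean)^2 / exp(logv_w)
    #                  - 1 + logv_w - logv )
    # summed over the batch -- the JMVAE-style terms that pull the joint posterior towards
    # each single-modality posterior.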

    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    print('starting training')
    for epoch in range(args.epochs):
        for split in splits:
            print('split: ', split, '\tepoch: ', epoch)
            #print(split)
            #print((w_datasets[split][0]))
            #print(w_datasets['train'])

            data_loader = DataLoader(
                dataset=w_datasets[split],  #y_datasets[split],
                batch_size=args.batch_size,
                shuffle=split == 'train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available())
            #print('Out dataloader received')
            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):
                #print('new batch')
                #print('batch')
                batch_size = batch['input'].size(0)
                #print(iteration,batch['labels'])
                batch['labels'] = batch['labels'].float()
                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)
                #print('labels preprocessed')
                # Forward pass
                logp, logp2, mean, logv, z, mean_w, logv_w, mean_y, logv_y = model(
                    batch['input'], batch['labels'], batch['length'])
                #print('forward pass done')
                # loss calculation
                NLL_loss, BCE_loss, KL_loss, KL_loss_w, KL_loss_y, KL_weight, NLL_w_avg = loss_fn_plus(
                    logp, logp2, batch['target'], batch['labels'],
                    batch['length'], mean, logv, mean_w, logv_w, mean_y,
                    logv_y, args.anneal_function, step, args.k, args.x0)
                #!!!!
                # MAYBE ADD WEIGHTS TO KL_W AND KL_Y BASED ON THEIR DIMENSIONALITY
                #!!!
                loss = (NLL_loss + args.bce_weight * BCE_loss + KL_weight *
                        (KL_loss + args.alpha *
                         (KL_loss_w + KL_loss_y))) / batch_size
                #print('loss calculated')

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                #print('backprop done')
                # bookkeeping
                # wrap the scalar loss in a 1-element tensor so torch.cat below does not fail
                # print(loss.data)
                # print(tracker['ELBO'])

                loss_data = torch.cuda.FloatTensor([
                    loss.data.item()
                ]) if torch.cuda.is_available() else torch.tensor(
                    [loss.data.item()])
                tracker['ELBO'] = torch.cat(
                    (tracker['ELBO'], loss_data)
                )  #Orig: tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data),1)

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(),
                                      NLL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/BCE Loss" % split.upper(),
                                      BCE_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss-w" % split.upper(),
                                      KL_loss_w.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss-y" % split.upper(),
                                      KL_loss_y.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    print(
                        "%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, BCE-Loss %9.4f, KL-Loss-joint %9.4f, KL-Loss-w %9.4f, KL-Loss-y %9.4f, KL-Weight %6.3f, NLL-word-Loss %9.4f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.data[0], NLL_loss.data[0] / batch_size,
                           BCE_loss.data[0] / batch_size, KL_loss.data[0] /
                           batch_size, KL_loss_w.data[0] / batch_size,
                           KL_loss_y.data[0] / batch_size, KL_weight,
                           NLL_w_avg.data[0]))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=w_datasets['train'].get_i2w(),
                        pad_idx=w_datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs,
                   torch.mean(tracker['ELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(),
                                  torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('./dumps', ts)):
                    os.makedirs('./dumps/' + ts)
                with open(
                        os.path.join('./dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w+') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train' and epoch % 10 == 0:
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)
Example #18
def generate(date, epoch, sentiment, n_samples):
    date = date
    cuda2 = torch.device('cuda:0')
    epoch = epoch
    #date = "2020-Feb-26-17:47:47"
    #exp_descr = pd.read_csv("EXP_DESCR/" + date + ".csv")
    #print("Pretained: ", exp_descr['pretrained'][0])
    #print("Bidirectional: ", exp_descr['Bidirectional'][0])
    #epoch = str(10)
    #data_dir = 'data'
    #

    params = pd.read_csv("Parameters/params.csv")
    params = params.set_index('time')
    exp_descr = params.loc[date]
    # 2019-Dec-02-09:35:25, 60,300,256,0.3,0.5,16,False,0.001,10,False

    embedding_size = exp_descr["embedding_size"]
    hidden_size = exp_descr["hidden_size"]
    rnn_type = exp_descr['rnn_type']
    word_dropout = exp_descr["word_dropout"]
    embedding_dropout = exp_descr["embedding_dropout"]
    latent_size = exp_descr["latent_size"]
    num_layers = 1
    batch_size = exp_descr["batch_size"]
    bidirectional = bool(exp_descr["bidirectional"])
    max_sequence_length = exp_descr["max_sequence_length"]
    back = exp_descr["back"]
    attribute_size = exp_descr["attr_size"]
    wd_type = exp_descr["word_drop_type"]
    num_samples = 2
    save_model_path = 'bin'
    ptb = False
    if ptb == True:
        vocab_dir = '/ptb.vocab.json'
    else:
        vocab_dir = '/yelp_vocab.json'

    with open("bin/" + date + "/" + vocab_dir, 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(vocab_size=len(w2i),
                        sos_idx=w2i['<sos>'],
                        eos_idx=w2i['<eos>'],
                        pad_idx=w2i['<pad>'],
                        unk_idx=w2i['<unk>'],
                        max_sequence_length=max_sequence_length,
                        embedding_size=embedding_size,
                        rnn_type=rnn_type,
                        hidden_size=hidden_size,
                        word_dropout=0,
                        embedding_dropout=0,
                        latent_size=latent_size,
                        num_layers=num_layers,
                        cuda=cuda2,
                        bidirectional=bidirectional,
                        attribute_size=attribute_size,
                        word_dropout_type='static',
                        back=back)

    print(model)
    # Results
    # 2019-Nov-28-13:23:06/E4-5".pytorch"

    load_checkpoint = "bin/" + date + "/" + "E" + str(epoch) + ".pytorch"
    # load_checkpoint = "bin/2019-Nov-28-12:03:44 /E0.pytorch"

    if not os.path.exists(load_checkpoint):
        raise FileNotFoundError(load_checkpoint)

    if torch.cuda.is_available():
        model = model.cuda()
        device = "cuda"
    else:
        device = "cpu"

    model.load_state_dict(
        torch.load(load_checkpoint, map_location=torch.device(device)))

    def attr_generation(n):
        labels = np.random.randint(2, size=n)
        enc = OneHotEncoder(handle_unknown='ignore')
        labels = np.reshape(labels, (len(labels), 1))
        enc.fit(labels)
        one_hot = enc.transform(labels).toarray()
        one_hot = one_hot.astype(np.float32)
        one_hot = torch.from_numpy(one_hot)
        return one_hot
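    # attr_generation draws n random binary sentiment labels and one-hot encodes them, so the
    # returned float32 tensor has one-hot rows such as [1., 0.] / [0., 1.] (shape (n, 2)
    # whenever both classes appear in the sample).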

    model.eval()
    labels = attr_generation(n=num_samples)

    from vaderSentiment.vaderSentiment import SentimentIntensityAnalyzer
    from sklearn.metrics import accuracy_score
    analyser = SentimentIntensityAnalyzer()

    def sentiment_analyzer_scores(sentence):
        score = analyser.polarity_scores(sentence)
        if score['compound'] > 0.05:
            return 1, 'Positive'
        else:
            return 0, 'Negative'
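    # VADER's conventional thresholds treat compound scores above about 0.05 as positive and
    # below about -0.05 as negative; this helper collapses everything at or below 0.05
    # (including neutral) into the 'Negative' bucket.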

    print('----------SAMPLES----------')
    labels = []
    generated = []
    for i in range(n_samples):
        samples, z, l = model.inference(sentiment)
        s = idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])
        #print(sentiment_analyzer_scores(s[0]))
        if sentiment_analyzer_scores(s[0])[1] == sentiment:
            generated.append(s[0])

        labels.append(sentiment_analyzer_scores(s[0])[0])
        #print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
    print(sum(labels))
    translation = idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])
    return generated
Example #19
def main(args):

    ################ config your params here ########################
    # ortho = False
    # attention = False
    # hspace_classifier = False
    # diversity = False # do not try this yet, need to fix bugs
    
    # create dir name
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())
    ts = ts.replace(':', '-')
    ts = ts+'-'+args.dataset

    
    if(args.ortho):
        ts = ts+'-ortho'
    if(args.hspace):
        ts = ts+'-hspace'
    if(args.attention):
        ts = ts+'-self-attn'

   

    if(args.dataset == "multitask"):
        print("Running multitask dataset!")
        vae_model = SentenceVaeMultiTask
        dataset = SnliYelp
    if(args.dataset == "snli"):
        print("Running SNLI!")
        vae_model = SentenceVaeSnli
        dataset = SNLI
    if(args.dataset == "yelp"):
        print("Running Yelp!")
        vae_model = SentenceVaeYelp
        dataset = Yelpd

     # prepare dataset
    splits = ['train', 'test']

    # create dataset object
    datasets = OrderedDict()
    

    # create test and train split in data, also preprocess
    for split in splits:
        print("creating dataset for: {}".format(split))
        datasets[split] = dataset(
            split=split,
            create_data=args.create_data,
            min_occ=args.min_occ
        )

    i2w = datasets['train'].get_i2w()
    w2i = datasets['train'].get_w2i()

    # get training params
    params = dict(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=datasets['train'].max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        ortho=args.ortho,
        attention=args.attention,
        hspace_classifier=args.hspace,
        diversity=args.diversity
    )

    # init model object
    model = vae_model(**params)

    if torch.cuda.is_available():
        model = model.cuda()

    # logging
    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    # make dir
    save_model_path = os.path.join(datasets["train"].save_model_path, ts)
    os.makedirs(save_model_path)

    # write params to json and save
    with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
        json.dump(params, f, indent=4)

    # define the function that returns the annealing (disentangling) weight applied to the KL loss at each training step

    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1/(1+np.exp(-k*(step-x0))))
        elif anneal_function == 'linear':
            return min(1, step/x0)

    # define the NLL loss used to measure the reconstruction accuracy of the decoder
    NLL = torch.nn.NLLLoss(ignore_index=datasets['train'].pad_idx, reduction='sum')

    loss_fn_2 = F.cross_entropy

    # this function computes the two loss terms and the KL loss weight
    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):

        # cut-off unnecessary padding from target, and flatten
       
        target = target[:, :datasets["train"].max_sequence_length].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood        
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0

    
    overall_losses = defaultdict(dict)

    loss_at_epoch = {
        'nll_loss': 0.0,
        'kl_loss': 0.0,
        'style_loss': 0.0,
        'content_loss': 0.0,
        'diversity_loss': 0.0,
        'hspace_loss': 0.0,
        'nll_loss_test': 0.0,
        'kl_loss_test': 0.0,
        'style_loss_test': 0.0,
        'content_loss_test': 0.0,
        'diversity_loss_test': 0.0,
        'hspace_loss_test': 0.0
    }

    for epoch in range(args.epochs):

        # do train and then test
        for split in splits:

            # create dataloader
            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split == 'train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )

            # tracker used to track the loss
            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            # start batch wise training/testing
            for iteration, batch in enumerate(data_loader):

                # get batch size
                batch_size = batch['input'].size(0)

               
                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # try sample
                # print(idx2word(batch['target'][0:1], i2w=i2w, pad_idx=w2i['<pad>']))
                # print(batch['label'][0])
                # continue
                # print("neg: {}, pos: {}".format(style_preds[0:1,0], style_preds[0:1,1]))

                # Forward pass
                logp, final_mean, final_logv, final_z, style_preds, content_preds, hspace_preds, diversity_loss = model(batch['input'], batch['length'], batch['label'], batch['bow'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'], batch['length'], final_mean, final_logv, args.anneal_function, step, args.k, args.x0)
        
                style_loss = nn.MSELoss()(style_preds, batch['label'].type(torch.FloatTensor).cuda())  # style prediction loss (MSE against the sentiment label)
                content_loss = nn.MSELoss()(content_preds, batch['bow'].type(torch.FloatTensor).cuda())  # content prediction loss (MSE against the bag-of-words vector)

                if hspace_preds is None:
                    hspace_classifier_loss = 0
                else:
                    hspace_classifier_loss = nn.MSELoss()(hspace_preds, batch['label'].type(torch.FloatTensor).cuda())

                # final loss calculation
                loss = (NLL_loss + KL_weight * KL_loss) / batch_size + 1000 * style_loss + 1000*content_loss
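                # note (added): the style/content terms are multiplied by 1000,
                # presumably to keep them comparable in magnitude to the
                # per-sentence ELBO term they are added to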
                # loss = (NLL_loss + KL_weight * KL_loss) / batch_size 

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()  # flush grads
                    
                    if args.diversity:
                        loss.backward(retain_graph = True)  # run bp
                        diversity_loss.backward()
                    else:
                        loss.backward()  # run bp

                    optimizer.step()  # run gd
                    step += 1

                

                # store a snapshot copy so that later in-place updates to loss_at_epoch
                # do not alias every previously stored entry
                overall_losses[len(overall_losses)] = dict(loss_at_epoch)

                # bookkeeping
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data.view(1, -1)), dim=0)

                # logging of losses
                if args.tensorboard_logging:
                    writer.add_scalar(
                        "%s/ELBO" % split.upper(), loss.item(), epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(), NLL_loss.item() / batch_size,
                                      epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(), KL_loss.item() / batch_size,
                                      epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(), KL_weight,
                                      epoch*len(data_loader) + iteration)
                                      

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f, Style-Loss %9.4f, Content-Loss %9.4f, Hspace-Loss %9.4f, Diversity-Loss %9.4f"
                          % (split.upper(), iteration, len(data_loader)-1, loss.item(), NLL_loss.item()/batch_size,
                             KL_loss.item()/batch_size, KL_weight, style_loss, content_loss, hspace_classifier_loss, diversity_loss))
                    
  

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(),pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], final_z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs, tracker['ELBO'].mean()))
            
             

            # more logging
            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(),
                                  torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'], 'z': tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/'+ts)
                with open(os.path.join('dumps/'+ts+'/valid_E%i.json' % epoch), 'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(
                    save_model_path, "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)
            
            # update losses log
            if split == 'train':
                loss_at_epoch['nll_loss'] = float(NLL_loss/args.batch_size)
                loss_at_epoch['kl_loss'] = float(KL_loss)
                loss_at_epoch['style_loss'] = float(style_loss)
                loss_at_epoch['content_loss'] = float(content_loss)
                loss_at_epoch['diversity_loss'] = float(diversity_loss)
                loss_at_epoch['hspace_loss'] = float(hspace_classifier_loss)
            else:
                loss_at_epoch['nll_loss_test'] = float(NLL_loss/args.batch_size)
                loss_at_epoch['kl_loss_test'] = float(KL_loss)
                loss_at_epoch['style_loss_test'] = float(style_loss)
                loss_at_epoch['content_loss_test'] = float(content_loss)
                loss_at_epoch['diversity_loss_test'] = float(diversity_loss)
                loss_at_epoch['hspace_loss_test'] = float(hspace_classifier_loss)
        
    # write losses to json
    with open(os.path.join(save_model_path, 'losses.json'), 'w') as f:
        json.dump(overall_losses, f, indent=4)
Example #20
def main(args):

    data_name = args.data_name
    with open(args.data_dir+data_name+'.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s"%(args.load_checkpoint))

    if torch.cuda.is_available():
        model = model.cuda()
    
    model.eval()

    # samples, z = model.inference(n=args.num_samples)
    # print('----------SAMPLES----------')
    # print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
    
    # z1 = torch.randn([args.latent_size]).numpy()
    # z2 = torch.randn([args.latent_size]).numpy()
    # z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    # samples, _ = model.inference(z=z)
    # print('-------INTERPOLATION-------')
    # print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    # print('-------Encode ... Decode-------')
    
    # datasets = Amazon(
    #         data_dir=args.data_dir,
    #         split="valid",
    #         create_data=False,
    #         batch_size=10,
    #         max_sequence_length=args.max_sequence_length,
    #         min_occ=3
    #     )


    ### load vocab
    # with open(os.path.join(args.data_dir, args.vocab_file), 'r') as file:
    #     vocab = json.load(file)
    #     w2i, i2w = vocab['w2i'], vocab['i2w']

    tokenizer = TweetTokenizer(preserve_case=False)

    # raw_text = "I like this!"
    raw_text = "DON'T CARE FOR IT.  GAVE IT AS A GIFT AND THEY WERE OKAY WITH IT.  JUST NOT WHAT I EXPECTED."
    input_text = f_raw2vec(tokenizer, raw_text, w2i, i2w)
    length_text = len(input_text)
    length_text = [length_text]
    print("length_text", length_text)

    input_tensor = torch.LongTensor(input_text)
    print('input_tensor', input_tensor)
    input_tensor = input_tensor.unsqueeze(0)
    if torch.is_tensor(input_tensor):
        input_tensor = to_var(input_tensor)

    length_tensor = torch.LongTensor(length_text)
    print("length_tensor", length_tensor)
    # length_tensor = length_tensor.unsqueeze(0)
    if torch.is_tensor(length_tensor):
        length_tensor = to_var(length_tensor)
    
    print("*"*10)
    print("->"*10, *idx2word(input_tensor, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
    logp, mean, logv, z = model(input_tensor, length_tensor)

    # print("z", z.size(), mean_z.size())
    mean = mean.unsqueeze(0)
    print("mean", mean)
    print("z", z)

    samples, z = model.inference(z=mean)
    print("<-"*10, *idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    for i in range(10):
        samples, z = model.inference(z=z)
        print("<-"*10, *idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
def main(args):

	ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

	splits = ['train', 'valid'] + (['test'] if args.test else [])

	datasets = OrderedDict()
	for split in splits:

		if args.dataset == 'ptb':
			Dataset = PTB
		elif args.dataset == 'twitter':
			Dataset = PoliticianTweets
		else:
			print("Invalid dataset. Exiting")
			exit()

		datasets[split] = Dataset(
			data_dir=args.data_dir,
			split=split,
			create_data=args.create_data,
			max_sequence_length=args.max_sequence_length,
			min_occ=args.min_occ
		)

	model = SentenceVAE(
		vocab_size=datasets['train'].vocab_size,
		sos_idx=datasets['train'].sos_idx,
		eos_idx=datasets['train'].eos_idx,
		pad_idx=datasets['train'].pad_idx,
		unk_idx=datasets['train'].unk_idx,
		max_sequence_length=args.max_sequence_length,
		embedding_size=args.embedding_size,
		rnn_type=args.rnn_type,
		hidden_size=args.hidden_size,
		word_dropout=args.word_dropout,
		embedding_dropout=args.embedding_dropout,
		latent_size=args.latent_size,
		num_layers=args.num_layers,
		bidirectional=args.bidirectional
		)

	# if args.from_file != "":
	# 	model = torch.load(args.from_file)
	#

	if torch.cuda.is_available():
		model = model.cuda()

	print(model)

	if args.tensorboard_logging:
		writer = SummaryWriter(os.path.join(args.logdir, experiment_name(args,ts)))
		writer.add_text("model", str(model))
		writer.add_text("args", str(args))
		writer.add_text("ts", ts)

	save_model_path = os.path.join(args.save_model_path, ts)
	os.makedirs(save_model_path)

	
	if 'sigmoid' in args.anneal_function and args.dataset=='ptb':
		linspace = np.linspace(-5,5,13160) # 13160 = number of training examples in ptb
	elif 'sigmoid' in args.anneal_function and args.dataset=='twitter':
		linspace = np.linspace(-5, 5, 25190) #6411/25190? = number of training examples in short version of twitter

	def kl_anneal_function(anneal_function, step, param_dict=None):
		if anneal_function == 'identity':
			return 1
		elif anneal_function == 'sigmoid' or anneal_function=='sigmoid_klt':
			s = 1/(len(linspace))
			return(float((1)/(1+np.exp(-param_dict['ag']*(linspace[step])))))

	NLL = torch.nn.NLLLoss(size_average=False, ignore_index=datasets['train'].pad_idx)
	def loss_fn(logp, target, length, mean, logv, anneal_function, step, param_dict=None):

		# cut-off unnecessary padding from target, and flatten
		target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
		logp = logp.view(-1, logp.size(2))
		
		# Negative Log Likelihood
		NLL_loss = NLL(logp, target)

		# KL Divergence
		KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
		if args.anneal_function == 'sigmoid_klt':
			if float(KL_loss)/args.batch_size < param_dict['kl_threshold']:
				# print("KL_loss of %s is below threshold %s. Returning this threshold instead"%(float(KL_loss)/args.batch_size,param_dict['kl_threshold']))
				KL_loss = to_var(torch.Tensor([param_dict['kl_threshold']*args.batch_size]))
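			# (note added for clarity) clamping the per-sentence KL to a floor like
			# this is the "free bits"-style trick: once the KL drops below
			# kl_threshold it no longer contributes gradient pressure, which helps
			# prevent the posterior from collapsing onto the prior.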
		KL_weight = kl_anneal_function(anneal_function, step, {'ag': args.anneal_aggression})

		return NLL_loss, KL_loss, KL_weight

	optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

	tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
	step = 0
	for epoch in range(args.epochs):

		for split in splits:

			data_loader = DataLoader(
				dataset=datasets[split],
				batch_size=args.batch_size,
				shuffle=split=='train',
				num_workers=0,
				pin_memory=torch.cuda.is_available()
			)

			tracker = defaultdict(tensor)

			# Enable/Disable Dropout
			if split == 'train':
				model.train()
			else:
				model.eval()

			for iteration, batch in enumerate(data_loader):

				batch_size = batch['input'].size(0)
				if split == 'train' and batch_size != args.batch_size:
					print("WARNING: Found different batch size\nargs.batch_size= %s, input_size=%s"%(args.batch_size, batch_size))
					

				for k, v in batch.items():
					if torch.is_tensor(v):
						batch[k] = to_var(v)

				# Forward pass
				logp, mean, logv, z = model(batch['input'], batch['length'])

				# loss calculation
				NLL_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
					batch['length'], mean, logv, args.anneal_function, step, {'kl_threshold': args.kl_threshold})

				loss = (NLL_loss + KL_weight * KL_loss)/batch_size

				# backward + optimization
				if split == 'train':
					optimizer.zero_grad()
					loss.backward()
					optimizer.step()
					step += 1
					# print(step)

				# bookkeeping
				tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data))

				
				if args.tensorboard_logging:
					writer.add_scalar("%s/ELBO"%split.upper(), loss.data[0], epoch*len(data_loader) + iteration)
					writer.add_scalar("%s/NLL_Loss"%split.upper(), NLL_loss.data[0]/batch_size, epoch*len(data_loader) + iteration)
					writer.add_scalar("%s/KL_Loss"%split.upper(), KL_loss.data[0]/batch_size, epoch*len(data_loader) + iteration)
					# print("Step %s: %s"%(epoch*len(data_loader) + iteration, KL_weight))
					writer.add_scalar("%s/KL_Weight"%split.upper(), KL_weight, epoch*len(data_loader) + iteration)

				if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
					logger.info("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
						%(split.upper(), iteration, len(data_loader)-1, loss.data[0], NLL_loss.data[0]/batch_size, KL_loss.data[0]/batch_size, KL_weight))

				if split == 'valid':
					if 'target_sents' not in tracker:
						tracker['target_sents'] = list()
					tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx)
					tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

			logger.info("%s Epoch %02d/%i, Mean ELBO %9.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['ELBO'])))

			if args.tensorboard_logging:
				writer.add_scalar("%s-Epoch/ELBO"%split.upper(), torch.mean(tracker['ELBO']), epoch)

			# save a dump of all sentences and the encoded latent space
			if split == 'valid':
				dump = {'target_sents':tracker['target_sents'], 'z':tracker['z'].tolist()}
				if not os.path.exists(os.path.join('dumps', ts)):
					os.makedirs('dumps/'+ts)
				with open(os.path.join('dumps/'+ts+'/valid_E%i.json'%epoch), 'w') as dump_file:
					json.dump(dump,dump_file)

			# save checkpoint
			if split == 'train':
				checkpoint_path = os.path.join(save_model_path, "E%i.pytorch"%(epoch))
				torch.save(model.state_dict(), checkpoint_path)
				logger.info("Model saved at %s"%checkpoint_path)

	torch.save(model, f"model-{args.dataset}-{ts}.pickle")
Example #22
File: train.py  Project: pvijayak/RNNVaE
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, x1, x2):
        if anneal_function == 'identity':
            return 1
        elif anneal_function == 'linear':
            return min(1, step / x1)
        elif anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-x2 * (step - x1))))
        elif anneal_function == 'cyclic_log':
            return float(1 / (1 + np.exp(-x2 * ((step % (3 * x1)) - x1))))
        elif anneal_function == 'cyclic_lin':
            return min(1, (step % (3 * x1)) / x1)
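    # (note added) the two 'cyclic_*' schedules above restart the warm-up every
    # 3*x1 steps: within each cycle the weight ramps up over the first x1 steps
    # (logistically or linearly), stays near 1 for the rest of the cycle, then
    # resets to (near) zero.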

    ReconLoss = torch.nn.NLLLoss(size_average=False,
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, x1,
                x2):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, x1, x2)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0

    early_stopping = EarlyStopping(history=10)
    for epoch in range(args.epochs):

        early_stopping_flag = False
        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            # tracker = defaultdict(tensor)
            tracker = defaultdict(list)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, args.x1, args.x2)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['negELBO'].append(loss.item())

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.item(),
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.item() / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    # print(step)
                    # logger.info("Step = %d"%step)
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), recon_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    # tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)
                    # print(z.data.shape)
                    tracker['z'].append(z.data.tolist())
            mean_loss = sum(tracker['negELBO']) / len(tracker['negELBO'])

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs, mean_loss))
            # print(mean_loss)

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  mean_loss, epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z']
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)
                if (args.early_stopping):
                    if (early_stopping.check(mean_loss)):
                        early_stopping_flag = True

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

        if (early_stopping_flag):
            print("Early stopping trigerred. Training stopped...")
            break
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def sigmoid(step):
        x = step - 6569.5
        if x < 0:
            a = np.exp(x)
            res = (a / (1 + a))
        else:
            res = (1 / (1 + np.exp(-x)))
        return float(res)
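    # (note added) the two branches above are the numerically stable form of the
    # sigmoid: np.exp is only ever called on a non-positive argument, so it cannot
    # overflow for very large |x|.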

    def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5):
        L = np.ones(n_iter) * stop
        period = n_iter / n_cycle
        step = (stop - start) / (period * ratio)  # linear schedule

        for c in range(n_cycle):
            v, i = start, 0
            while v <= stop and (int(i + c * period) < n_iter):
                L[int(i + c * period)] = v
                v += step
                i += 1
        return L
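    # Illustration (added; not part of the original script): for a short run,
    #   frange_cycle_linear(8, n_cycle=2) returns [0.0, 0.5, 1.0, 1.0, 0.0, 0.5, 1.0, 1.0]
    # i.e. the KL weight ramps linearly to 1 over the first half of each cycle
    # and stays at 1 for the second half.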

    n_iter = 0
    for epoch in range(args.epochs):
        split = 'train'
        data_loader = DataLoader(dataset=datasets[split],
                                 batch_size=args.batch_size,
                                 shuffle=split == 'train',
                                 num_workers=cpu_count(),
                                 pin_memory=torch.cuda.is_available())

        for iteration, batch in enumerate(data_loader):
            n_iter += 1
    print("Total no of iterations = " + str(n_iter))

    L = frange_cycle_linear(n_iter)

    def kl_anneal_function(anneal_function, step):
        if anneal_function == 'identity':
            return 1

        if anneal_function == 'sigmoid':
            return sigmoid(step)

        if anneal_function == 'cyclic':
            return float(L[step])

    ReconLoss = torch.nn.NLLLoss(size_average=False,
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp,
                target,
                length,
                mean,
                logv,
                anneal_function,
                step,
                split='train'):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        if split == 'train':
            KL_weight = kl_anneal_function(anneal_function, step)
        else:
            KL_weight = 1

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, split)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['negELBO'] = torch.cat((tracker['negELBO'], loss.data))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.data[0], recon_loss.data[0] / batch_size,
                           KL_loss.data[0] / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs,
                         torch.mean(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  torch.mean(tracker['negELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)
Example #24
def main(args):

    with open(args.data_dir + '/ptb/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    # load params
    params = load_model_params_from_checkpoint(args.load_params)

    # create model
    model = SentenceVAE(**params)

    print(model)
    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s" % args.load_checkpoint)
    # splits = ['train', 'test']
    splits = ['test']

    if torch.cuda.is_available():
        model = model.cuda()

    model.eval()
    datasets = OrderedDict()
    tsne_values = np.empty((0, 256), int)
    tsne_labels = np.empty((0, 2), int)

    for split in splits:
        print("creating dataset for: {}".format(split))
        datasets[split] = PTB(split=split,
                              create_data=args.create_data,
                              min_occ=args.min_occ)

    total_bleu = 0.0
    total_iterations = 0

    for split in splits:

        # create dataloader
        data_loader = DataLoader(dataset=datasets[split],
                                 batch_size=args.batch_size,
                                 shuffle=split == 'train',
                                 num_workers=cpu_count(),
                                 pin_memory=torch.cuda.is_available())
        for iteration, batch in enumerate(data_loader):

            # get batch size
            batch_size = batch['input'].size(0)

            for k, v in batch.items():
                if torch.is_tensor(v):
                    batch[k] = to_var(v)

            logp = model.bleu(batch['input'], batch['length'])
            gen_sents = idx2word(logp, i2w=i2w, pad_idx=w2i['<pad>'])
            batch_sents = idx2word(batch['input'],
                                   i2w=i2w,
                                   pad_idx=w2i['<pad>'])
            generated = [line.strip().split() for line in gen_sents]
            actual = [line.strip().split() for line in batch_sents]
            all_actual = [actual for i in range(len(generated))]
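            # (note added for clarity) corpus_bleu expects one *list of reference
            # sentences* per hypothesis; here every generated sentence is scored
            # against all sentences of the batch as its reference set, and the
            # resulting per-batch BLEU scores are averaged below.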
            bleus = nltk.translate.bleu_score.corpus_bleu(
                all_actual, generated)
            total_bleu = total_bleu + bleus
            total_iterations = iteration + 1
            # if iteration==:
            #     break

        bleu_score = total_bleu / total_iterations
        print(bleu_score)
def main(args):

    data_name = args.data_name
    with open(args.data_dir+data_name+'.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    model = SentenceVAE(
        vocab_size=len(w2i),
        sos_idx=w2i['<sos>'],
        eos_idx=w2i['<eos>'],
        pad_idx=w2i['<pad>'],
        unk_idx=w2i['<unk>'],
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if not os.path.exists(args.load_checkpoint):
        raise FileNotFoundError(args.load_checkpoint)

    model.load_state_dict(torch.load(args.load_checkpoint))
    print("Model loaded from %s"%(args.load_checkpoint))

    if torch.cuda.is_available():
        model = model.cuda()
    
    model.eval()

    samples, z = model.inference(n=args.num_samples)
    print('----------SAMPLES----------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
    
    z1 = torch.randn([args.latent_size]).numpy()
    z2 = torch.randn([args.latent_size]).numpy()
    z = to_var(torch.from_numpy(interpolate(start=z1, end=z2, steps=8)).float())
    samples, _ = model.inference(z=z)
    print('-------INTERPOLATION-------')
    print(*idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')

    print('-------Encode ... Decode-------')
    
    datasets = Amazon(
            data_dir=args.data_dir,
            split="valid",
            create_data=False,
            batch_size=10,
            max_sequence_length=args.max_sequence_length,
            min_occ=3
        )

    iteration = 0
    for input_batch_tensor, target_batch_tensor, length_batch_tensor in datasets:
        if torch.is_tensor(input_batch_tensor):
            input_batch_tensor = to_var(input_batch_tensor)

        if torch.is_tensor(target_batch_tensor):
            target_batch_tensor = to_var(target_batch_tensor)

        if torch.is_tensor(length_batch_tensor):
            length_batch_tensor = to_var(length_batch_tensor)

        print("*"*10)
        print("->"*10, *idx2word(input_batch_tensor, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
        logp, mean, logv, z = model(input_batch_tensor,length_batch_tensor)

        
        samples, z = model.inference(z=z)
        print("<-"*10, *idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>']), sep='\n')
        # print("+"*10)
        if iteration == 0:
            break

        iteration += 1
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    # The Twitter user who we want to get tweets from
    BS = "SenSanders"
    EW = "ewarren"
    AOC = "AOC"
    HC = "HillaryClinton"
    MM = "senatemajldr"
    LG = "LindseyGrahamSC"
    DT = "realDonaldTrump"
    DD = "SenatorDurbin"
    JM = "Sen_JoeManchin"
    JT = "SenatorTester"
    MR = "MittRomney"
    KM = "GOPLeader"
    DC = "RepDougCollins"
    CS = "SenSchumer"
    CB = "cbellantoni"
    EE = "ewerickson"
    MF = "mindyfinn"
    GG = "ggreenwald"
    NP = "nicopitney"
    TC = "TPCarney"
    AC = "anamariecox"
    DB = "donnabrazile"
    TCar = "TuckerCarlson"
    politicians = [
        BS, EW, AOC, HC, MM, LG, DT, DD, JM, JT, MR, KM, DC, CS, CB, EE, MF,
        GG, NP, TC, AC, DB, TCar
    ]
    partial_splits = ["test." + pol for pol in politicians]

    # Test splits are generated from GetTweets.py
    splits = ["train", "valid"] + partial_splits

    datasets = OrderedDict()
    for split in splits:

        if args.dataset == 'ptb':
            Dataset = PTB
        elif args.dataset == 'twitter':
            Dataset = PoliticianTweets
        else:
            print("Invalid dataset. Exiting")
            exit()

        datasets[split] = Dataset(data_dir=args.data_dir,
                                  split=split,
                                  create_data=args.create_data,
                                  max_sequence_length=args.max_sequence_length,
                                  min_occ=args.min_occ)

    # Must specify the pickle file from which to load the model
    if args.from_file != "":
        model = torch.load(args.from_file)
        checkpoint = torch.load(
            "/home/jakemdaly/PycharmProjects/vae/pa2/Language-Modelling-CSE291-AS2/bin/2020-May-26-06:03:46/E2.pytorch"
        )
        model.load_state_dict(checkpoint)
        print("Model loaded from file.")
    else:
        print(
            "Must be initialized with a pretrained model/pickle file. Exiting..."
        )
        exit()

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    # These are the dictionaries that get dumped to json
    PoliticianSentences = {}
    PoliticianLatents = {}
    for split in splits[2:]:
        PoliticianLatents[split] = None
        PoliticianSentences[split] = None

    for epoch in range(args.epochs):

        for split in splits[2:]:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=0,
                                     pin_memory=torch.cuda.is_available())

            # Enable/Disable Dropout
            if split == 'train' or split == 'valid':
                continue
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                # Get kth batch
                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Latent variables
                try:
                    if PoliticianLatents[split] is None:
                        PoliticianLatents[split] = get_latent(
                            model, batch['input'],
                            batch['length']).data.numpy()
                    else:
                        # pesky situation: if it has only one dimension, we need to unsqueeze it so it can be appended.
                        if len(
                                np.shape(
                                    get_latent(model, batch['input'],
                                               batch['length']).data.
                                    numpy())) == 1 and np.shape(
                                        get_latent(model, batch['input'],
                                                   batch['length']).data.numpy(
                                                   ))[0] == args.latent_size:
                            PoliticianLatents[split] = np.append(
                                PoliticianLatents[split],
                                np.expand_dims(
                                    get_latent(model, batch['input'],
                                               batch['length']).data.numpy(),
                                    0),
                                axis=0)
                        else:
                            PoliticianLatents[split] = np.append(
                                PoliticianLatents[split],
                                get_latent(model, batch['input'],
                                           batch['length']).data.numpy(),
                                axis=0)
                except:
                    print(split)
                # # Sentences corresponding to the latent mappings above
                if PoliticianSentences[split] is None:
                    PoliticianSentences[split] = idx2word(
                        batch['input'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                else:
                    PoliticianSentences[split].append(
                        idx2word(batch['input'].data,
                                 i2w=datasets['train'].get_i2w(),
                                 pad_idx=datasets['train'].pad_idx))

    # Dump all data to json files for analysis
    for split in splits[2:]:
        PoliticianLatents[split] = PoliticianLatents[split].tolist()
    with open("PoliticianLatentsE2.json", 'w') as file:
        json.dump(PoliticianLatents, file)
    with open("PoliticianSentences.json", 'w') as file:
        json.dump(PoliticianSentences, file)
    print("Done.")
Example #27
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] #+ (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ
        )

    model = SentenceVAE(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if torch.cuda.is_available():
        model = model.cuda()

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join('./',args.logdir, expierment_name(args,ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join('./',args.save_model_path,'VAE', ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1/(1+np.exp(-k*(step-x0))))
        elif anneal_function == 'linear':
            return min(1, step/x0)

    NLL = torch.nn.NLLLoss(size_average=False, ignore_index=datasets['train'].pad_idx)
    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))
        
        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)
        NLL_w_avg = NLL_loss/torch.sum(length).float()
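        # (note added) NLL_w_avg is the per-word NLL: the summed batch NLL divided
        # by the total number of tokens, i.e. the quantity whose exponential would
        # give a word-level perplexity.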

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight,NLL_w_avg
    print(model)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split=='train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )

            tracker = defaultdict(tensor)
 
            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight,NLL_w_avg = loss_fn(logp, batch['target'],
                    batch['length'], mean, logv, args.anneal_function, step, args.k, args.x0)

                loss = (NLL_loss + KL_weight * KL_loss)/batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1


                # bookkeeping
                # avoid the torch.cat error: loss.data is a 0-dim tensor in newer
                # PyTorch, so wrap it in a 1-element tensor before concatenating
                #print(loss.data)
                #print(tracker['ELBO'])
                loss_data = torch.tensor([loss.data.item()])
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss_data))  # Orig: tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data),1)

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO"%split.upper(), loss.data[0], epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss"%split.upper(), NLL_loss.data[0]/batch_size, epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss"%split.upper(), KL_loss.data[0]/batch_size, epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight"%split.upper(), KL_weight, epoch*len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f, NLL-word-Loss %9.4f"
                        %(split.upper(), iteration, len(data_loader)-1, loss.data[0], NLL_loss.data[0]/batch_size, KL_loss.data[0]/batch_size, KL_weight,NLL_w_avg))
                
                #split = 'invalid' #JUST TO DEBUG!!!
                
                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx) #ERROR HERE!!!
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['ELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO"%split.upper(), torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {'target_sents':tracker['target_sents'], 'z':tracker['z'].tolist()}
                if not os.path.exists(os.path.join('./dumps', ts)):
                    os.makedirs('dumps/'+ts)
                with open(os.path.join('./dumps/'+ts+'/valid_E%i.json'%epoch), 'w') as dump_file:
                    json.dump(dump,dump_file)

            # save checkpoint
            if split == 'train' and epoch %10 ==0 :
                checkpoint_path = os.path.join(save_model_path, "E%i.pytorch"%(epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s"%checkpoint_path)
Example #28
    with torch.no_grad():
        for i in xrange(N_minibatches):
            # sample from z ~ N(0,1)
            z = torch.randn(args.batch_size, train_args.z_dim)
            z = z.to(device)

            seq, length = model.program_decoder.sample(z,
                                                       train_args.max_seq_len,
                                                       greedy=False)
            label = model.label_decoder(z)

            # stop storing these on GPU
            seq, label = seq.cpu(), label.cpu()

            # convert programs to strings
            seq = idx2word(seq, i2w=i2w, pad_idx=w2i[PAD_TOKEN])

            # convert labels to strings
            label = [
                tensor_to_labels(t, label_dim, ix_to_label) for t in label
            ]

            programs.extend(seq)
            labels.extend(label)

            pbar.update()

        if N_leftover > 0:
            z = torch.randn(N_leftover, train_args.z_dim)
            z = z.to(device)
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    curBest = 1000000
    for split in splits:
        datasets[split] = Mixed(data_dir=args.data_dir,
                                split=split,
                                create_data=args.create_data,
                                max_sequence_length=args.max_sequence_length,
                                min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, totalIterations, split):
        if (split != 'train'):
            return 1
        elif anneal_function == 'identity':
            return 1
        elif anneal_function == 'linear':
            return 1.005 * float(step) / totalIterations
        elif anneal_function == 'sigmoid':
            return (1 / (1 + math.exp(-8 * (float(step) / totalIterations))))
        elif anneal_function == 'tanh':
            return math.tanh(4 * (float(step) / totalIterations))
        elif anneal_function == 'linear_capped':
            #print(float(step)*30/totalIterations)
            return min(1.0, float(step) * 5 / totalIterations)
        elif anneal_function == 'cyclic':
            quantile = int(totalIterations / 5)
            remainder = int(step % quantile)
            midPoint = int(quantile / 2)
            if (remainder > midPoint):
                return 1
            else:
                return float(remainder) / midPoint
        else:
            return 1

    ReconLoss = torch.nn.NLLLoss(size_average=False,
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step,
                totalIterations, split):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        #print((1 + logv - mean.pow(2) - logv.exp()).size())

        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        #print(KL_loss.size())
        KL_weight = kl_anneal_function(anneal_function, step, totalIterations,
                                       split)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    tensor2 = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    tensor3 = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    tensor4 = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor

    step = 0
    stop = False
    Z = []
    L = []
    for epoch in range(args.epochs):
        if (stop):
            break
        for split in splits:
            if (split == 'test'):
                z_data = []
                domain_label = []
                z_bool = False
                domain_label_bool = False
            if (stop):
                break
            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            totalIterations = (int(len(datasets[split]) / args.batch_size) +
                               1) * args.epochs

            tracker = defaultdict(tensor)
            tracker2 = defaultdict(tensor)
            tracker3 = defaultdict(tensor)
            tracker4 = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):
                #                 if(iteration > 400):
                #                     break
                batch_size = batch['input'].size(0)
                labels = batch['label']

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])
                if (split == 'test'):
                    if (z_bool == False):
                        z_bool = True
                        domain_label = labels.tolist()
                        z_data = z
                    else:
                        domain_label += labels.tolist()
                        #print(domain_label)
                        z_data = torch.cat((z_data, z), 0)

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, totalIterations, split)

                if split == 'train':
                    #KL_loss_thresholded = torch.clamp(KL_loss, min=6.0)
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['negELBO'] = torch.cat(
                    (tracker['negELBO'], loss.data.view(1)))
                tracker2['KL_loss'] = torch.cat(
                    (tracker2['KL_loss'], KL_loss.data.view(1)))
                tracker3['Recon_loss'] = torch.cat(
                    (tracker3['Recon_loss'], recon_loss.data.view(1)))
                tracker4['Perplexity'] = torch.cat(
                    (tracker4['Perplexity'],
                     torch.exp(recon_loss.data / batch_size).view(1)))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), recon_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if (split == 'test'):
                    Z = z_data
                    L = domain_label

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs,
                         torch.mean(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  torch.mean(tracker['negELBO']), epoch)
                writer.add_scalar("%s-Epoch/KL_loss" % split.upper(),
                                  torch.mean(tracker2['KL_loss']) / batch_size,
                                  epoch)
                writer.add_scalar(
                    "%s-Epoch/Recon_loss" % split.upper(),
                    torch.mean(tracker3['Recon_loss']) / batch_size, epoch)
                writer.add_scalar("%s-Epoch/Perplexity" % split.upper(),
                                  torch.mean(tracker4['Perplexity']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                if (torch.mean(tracker['negELBO']) < curBest):
                    curBest = torch.mean(tracker['negELBO'])
                else:
                    stop = True
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps_32_0', ts)):
                    os.makedirs('dumps_32_0/' + ts)
                with open(
                        os.path.join('dumps_32_0/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            # if split == 'train':
            #     checkpoint_path = os.path.join(save_model_path, "E%i.pytorch"%(epoch))
            #     torch.save(model.state_dict(), checkpoint_path)
            #     logger.info("Model saved at %s"%checkpoint_path)

    Z = Z.data.cpu().numpy()
    print(Z.shape)
    beforeTSNE = TSNE(random_state=20150101).fit_transform(Z)
    scatter(beforeTSNE, L, [0, 1, 2], (5, 5), 'latent discoveries')
    plt.savefig('mixed_tsne' + args.anneal_function + '.png', dpi=120)
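

# The `scatter` helper called above is not shown in this snippet. A minimal
# sketch of a compatible plotting helper follows, assuming it takes the 2-D
# t-SNE points, a list of integer domain labels, the label values to plot, a
# figure size, and a plot title; the project's actual helper may differ.
def scatter_sketch(points_2d, labels, label_values, figsize, title):
    """Color-coded 2-D scatter plot, one color per label value."""
    import numpy as np
    import matplotlib.pyplot as plt

    labels = np.asarray(labels)
    plt.figure(figsize=figsize)
    for value in label_values:
        mask = labels == value
        plt.scatter(points_2d[mask, 0], points_2d[mask, 1], s=8, label=str(value))
    plt.title(title)
    plt.legend()

# Hypothetical usage mirroring the call above:
# scatter_sketch(beforeTSNE, L, [0, 1, 2], (5, 5), 'latent discoveries')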
Example #30
0
def main(args):

    # Load the vocab
    with open(args.data_dir+'/ptb.vocab.json', 'r') as file:
        vocab = json.load(file)

    w2i, i2w = vocab['w2i'], vocab['i2w']

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    # Initialize semantic loss
    sl = Semantic_Loss()

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ
        )

    params = dict(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
    )
    model = SentenceVAE(**params)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join(args.logdir, expierment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
        json.dump(params, f, indent=4)

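    # KL-weight annealing for this training loop: 'logistic' is a sigmoid ramp
    # centered at step x0 with steepness k; 'linear' ramps from 0 to 1 over the
    # first x0 steps and then holds at 1.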
    def kl_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1/(1+np.exp(-k*(step-x0))))
        elif anneal_function == 'linear':
            return min(1, step/x0)

    def perplexity_anneal_function(anneal_function, step, k, x0):
        if anneal_function == 'logistic':
            return float(1/(1+np.exp(-k*(step-x0))))
        elif anneal_function == 'linear':
            return min(1, (step/x0))

    NLL = torch.nn.NLLLoss(ignore_index=datasets['train'].pad_idx, reduction='sum')
    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k, x0, \
        batch_perplexity, perplexity_anneal_function):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL Divergence
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)

        # Perplexity
        perp_loss = batch_perplexity
        perp_weight = perplexity_anneal_function(anneal_function, step, k, x0)

        return NLL_loss, KL_loss, KL_weight, perp_loss, perp_weight


    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        # Keep track of epoch loss
        epoch_loss = []

        for split in splits:

            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split=='train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            batch_t_start = None

            for iteration, batch in enumerate(data_loader):

                if batch_t_start:
                    batch_run_time = time.time() - batch_t_start
                    # print("Batch run time: " + str(batch_run_time))
                batch_t_start = time.time()


                batch_size = batch['input_sequence'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Get the original sentences in this batch
                batch_sentences = idx2word(batch['input_sequence'], i2w=i2w, pad_idx=w2i['<pad>'])
                # Remove the first tag
                batch_sentences = [x.replace("<sos>", "") for x in batch_sentences]

                # Forward pass
                (logp, mean, logv, z), states = model(**batch)


                # Choose some random pairs of samples within the batch
                #  to get latent representations for
                batch_index_pairs = list(itertools.combinations(np.arange(batch_size), 2))
                random.shuffle(batch_index_pairs)
                batch_index_pairs = batch_index_pairs[:args.perplexity_samples_per_batch]

                batch_perplexity = []

                # Only apply the perplexity term after a warm-up of 10 epochs
                start_perplexity = epoch > 10

                # If the perplexity loss is enabled and past the warm-up period
                if start_perplexity and args.perplexity_loss:
                    # For each pair, get the intermediate representations in the latent space
                    for index_pair in batch_index_pairs:

                        with torch.no_grad():
                            z1_hidden = states['z'][index_pair[0]].cpu()
                            z2_hidden = states['z'][index_pair[1]].cpu()

                        z_hidden = to_var(torch.from_numpy(interpolate(start=z1_hidden, end=z2_hidden, steps=1)).float())

                        if args.rnn_type == "lstm":

                            with torch.no_grad():
                                z1_cell_state = states['z_cell_state'].cpu().squeeze()[index_pair[0]]
                                z2_cell_state = states['z_cell_state'].cpu().squeeze()[index_pair[1]]

                            z_cell_states = \
                                to_var(torch.from_numpy(interpolate(start=z1_cell_state, end=z2_cell_state, steps=1)).float())

                            samples, _ = model.inference(z=z_hidden, z_cell_state=z_cell_states)
                        else:
                            samples, _ = model.inference(z=z_hidden, z_cell_state=None)

                        # Check interpolated sentences
                        interpolated_sentences = idx2word(samples, i2w=i2w, pad_idx=w2i['<pad>'])
                        # Compute the perplexity of each interpolated sentence
                        perplexities = []
                        for sentence in interpolated_sentences:
                            perplexities.append(sl.get_perplexity(sentence))
                        avg_sample_perplexity = sum(perplexities) / len(perplexities)
                        batch_perplexity.append(avg_sample_perplexity)
                    # Calculate batch perplexity
                    avg_batch_perplexity = sum(batch_perplexity) / len(batch_perplexity)

                    # loss calculation
                    NLL_loss, KL_loss, KL_weight, perp_loss, perp_weight = loss_fn(logp, batch['target'],
                        batch['length'], mean, logv, args.anneal_function, step, \
                            args.k, args.x0, avg_batch_perplexity, perplexity_anneal_function)

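                    # Annealed ELBO (NLL + KL_weight * KL) per sample, plus the
                    # weighted average perplexity of the interpolated samples as
                    # an extra regularizer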
                    loss = ((NLL_loss + KL_weight * KL_loss) / batch_size) + (perp_loss * perp_weight)

                else: # Warm-up epochs or perplexity loss disabled: train without the perplexity term
                    # loss calculation
                    NLL_loss, KL_loss, KL_weight, perp_loss, perp_weight = loss_fn(logp, batch['target'],
                        batch['length'], mean, logv, args.anneal_function, step, \
                            args.k, args.x0, 0, perplexity_anneal_function)

                    loss = (NLL_loss + KL_weight * KL_loss) / batch_size


                # Restore the train/eval mode, since model.inference() may switch the model to eval
                if split == 'train':
                    model.train()
                else:
                    model.eval()

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                    # Add loss
                    epoch_loss.append(loss.item())

                # bookkeeping
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data.view(1, -1)), dim=0)

                if args.tensorboard_logging:
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.item(), epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/NLL Loss" % split.upper(), NLL_loss.item() / batch_size,
                                      epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Loss" % split.upper(), KL_loss.item() / batch_size,
                                      epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL Weight" % split.upper(), KL_weight,
                                      epoch*len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f, Perp-loss %9.4f, Perp-weight %6.3f"
                          % (split.upper(), iteration, len(data_loader)-1, loss.item(), NLL_loss.item()/batch_size,
                          KL_loss.item()/batch_size, KL_weight, perp_loss, perp_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(),
                                                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" % (split.upper(), epoch, args.epochs, tracker['ELBO'].mean()))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(), torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {'target_sents': tracker['target_sents'], 'z': tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/'+ts)
                with open(os.path.join('dumps/'+ts+'/valid_E%i.json' % epoch), 'w') as dump_file:
                    json.dump(dump,dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path, "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)