def main(args):
    """Train a SentenceVAE on the PTB corpus with a configurable KL-annealing schedule.

    Each epoch runs the ``train`` and ``valid`` splits (plus ``test`` when
    ``args.test`` is set), logs per-batch and per-epoch losses, dumps the
    validation sentences and latent codes to ``dumps/<ts>/valid_E<epoch>.json``
    and saves a ``state_dict`` checkpoint to
    ``<args.save_model_path>/<ts>/E<epoch>.pytorch`` after each training pass.

    Args:
        args: parsed command-line namespace. Besides the dataset/model
            hyper-parameters forwarded to ``PTB`` and ``SentenceVAE``, this
            reads ``epochs``, ``batch_size``, ``learning_rate``,
            ``anneal_function`` (``'identity'``, ``'sigmoid'`` or
            ``'cyclic'``), ``print_every``, ``test``, ``tensorboard_logging``,
            ``logdir`` and ``save_model_path``.

    Raises:
        ValueError: if ``args.anneal_function`` is not a supported schedule.
    """
    # Fail fast: an unknown schedule used to surface only as a TypeError
    # (None * tensor) on the first training batch.
    supported_schedules = ('identity', 'sigmoid', 'cyclic')
    if args.anneal_function not in supported_schedules:
        raise ValueError("Unsupported anneal_function %r; expected one of %s"
                         % (args.anneal_function, supported_schedules))

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def sigmoid(step):
        """Logistic KL weight centred at step 6569.5 (hard-coded midpoint).

        The two branches are the numerically stable forms of 1 / (1 + e^-x):
        neither branch exponentiates a large positive number.
        """
        x = step - 6569.5
        if x < 0:
            a = np.exp(x)
            res = (a / (1 + a))
        else:
            res = (1 / (1 + np.exp(-x)))
        return float(res)

    def frange_cycle_linear(n_iter, start=0.0, stop=1.0, n_cycle=4, ratio=0.5):
        """Return ``n_iter`` per-step weights forming ``n_cycle`` linear ramps.

        Each cycle of ``n_iter / n_cycle`` steps rises linearly from ``start``
        towards ``stop`` over roughly its first ``ratio`` fraction and then
        stays at ``stop``. Requires ``n_iter > 0``; otherwise the slope
        computation divides by zero.
        """
        weights = np.ones(n_iter) * stop
        period = n_iter / n_cycle
        step = (stop - start) / (period * ratio)  # linear schedule

        for c in range(n_cycle):
            v, i = start, 0
            while v <= stop and (int(i + c * period) < n_iter):
                weights[int(i + c * period)] = v
                v += step
                i += 1
        return weights

    # Total number of optimisation steps in the run. len(DataLoader) is the
    # number of batches per epoch, so there is no need to iterate the loader
    # (which would load the whole training set args.epochs extra times).
    steps_per_epoch = len(DataLoader(dataset=datasets['train'],
                                     batch_size=args.batch_size))
    n_iter = steps_per_epoch * args.epochs
    print("Total no of iterations = " + str(n_iter))

    # Per-step weights for the cyclic schedule. Only built when that schedule
    # is selected and there is at least one step, because the ramp slope
    # divides by the cycle length.
    if args.anneal_function == 'cyclic' and n_iter > 0:
        cyclic_weights = frange_cycle_linear(n_iter)
    else:
        cyclic_weights = np.ones(0)

    def kl_anneal_function(anneal_function, step):
        """Return the KL weight for training step ``step`` under ``anneal_function``."""
        if anneal_function == 'identity':
            return 1

        if anneal_function == 'sigmoid':
            return sigmoid(step)

        if anneal_function == 'cyclic':
            # step < n_iter always holds: it only advances once per training batch.
            return float(cyclic_weights[step])

        raise ValueError("Unsupported anneal_function %r" % (anneal_function,))

    # Summed (not averaged) token NLL; padding positions are ignored.
    ReconLoss = torch.nn.NLLLoss(size_average=False,
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp,
                target,
                length,
                mean,
                logv,
                anneal_function,
                step,
                split='train'):
        """Return ``(recon_loss, KL_loss, KL_weight)`` for one batch.

        Both losses are sums over the batch. The KL weight follows the
        annealing schedule during training and is 1 on other splits, so
        evaluation reports the full ELBO.
        """
        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL divergence between the diagonal Gaussian posterior and N(0, I)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        if split == 'train':
            KL_weight = kl_anneal_function(anneal_function, step)
        else:
            KL_weight = 1

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, split)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['negELBO'] = torch.cat((tracker['negELBO'], loss.data))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.data[0], recon_loss.data[0] / batch_size,
                           KL_loss.data[0] / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs,
                         torch.mean(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  torch.mean(tracker['negELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

    # Flush pending events and release the event file.
    if args.tensorboard_logging:
        writer.close()
def main(args):
    """Train a SentenceVAE on the ``Mixed`` dataset with early stopping, then
    plot a t-SNE projection of the test-set latent codes.

    Each epoch runs ``train``, ``valid`` and (when ``args.test`` is set)
    ``test``. Training stops as soon as the mean validation negative ELBO
    fails to improve. Validation sentences and latent codes are dumped to
    ``dumps_32_0/<ts>/valid_E<epoch>.json``. When test latent codes were
    collected, their t-SNE embedding, coloured by ``batch['label']``, is
    saved to ``mixed_tsne<anneal_function>.png``.

    Args:
        args: parsed command-line namespace. Besides the dataset/model
            hyper-parameters forwarded to ``Mixed`` and ``SentenceVAE``, this
            reads ``epochs``, ``batch_size``, ``learning_rate``,
            ``anneal_function``, ``print_every``, ``test``,
            ``tensorboard_logging``, ``logdir`` and ``save_model_path``.
    """
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    # Lowest mean validation negative ELBO so far; the first epoch that does
    # not improve on it triggers early stopping.
    curBest = 1000000
    for split in splits:
        datasets[split] = Mixed(data_dir=args.data_dir,
                                split=split,
                                create_data=args.create_data,
                                max_sequence_length=args.max_sequence_length,
                                min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, totalIterations, split):
        """Return the KL weight for ``step`` out of ``totalIterations``.

        Non-training splits and unknown schedule names always get weight 1.
        """
        if split != 'train':
            return 1
        if anneal_function == 'identity':
            return 1
        if anneal_function == 'linear':
            return 1.005 * float(step) / totalIterations
        if anneal_function == 'sigmoid':
            return (1 / (1 + math.exp(-8 * (float(step) / totalIterations))))
        if anneal_function == 'tanh':
            return math.tanh(4 * (float(step) / totalIterations))
        if anneal_function == 'linear_capped':
            return min(1.0, float(step) * 5 / totalIterations)
        if anneal_function == 'cyclic':
            # Five cycles over the run: ramp 0 -> 1 over the first half of a
            # cycle, then hold at 1. max(1, ...) guards against modulo /
            # division by zero when totalIterations is very small.
            quantile = max(1, int(totalIterations / 5))
            remainder = int(step % quantile)
            midPoint = max(1, int(quantile / 2))
            if remainder > midPoint:
                return 1
            return float(remainder) / midPoint
        return 1

    # Summed (not averaged) token NLL; padding positions are ignored.
    ReconLoss = torch.nn.NLLLoss(size_average=False,
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step,
                totalIterations, split):
        """Return ``(recon_loss, KL_loss, KL_weight)`` summed over one batch."""
        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL divergence between the diagonal Gaussian posterior and N(0, I)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, totalIterations,
                                       split)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor

    step = 0
    stop = False
    # Latent codes (Z) and domain labels (L) from the most recent test pass,
    # used for the t-SNE plot after training.
    Z = []
    L = []
    for epoch in range(args.epochs):
        if stop:
            break
        for split in splits:
            if stop:
                break
            if split == 'test':
                z_data = []
                domain_label = []
                z_bool = False

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            totalIterations = (int(len(datasets[split]) / args.batch_size) +
                               1) * args.epochs

            # One tracker for all per-batch metrics (distinct keys).
            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):
                batch_size = batch['input'].size(0)
                labels = batch['label']

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])
                if split == 'test':
                    # Keep detached codes (z.data): accumulating z itself
                    # would retain the autograd graph of every test batch.
                    if not z_bool:
                        z_bool = True
                        domain_label = labels.tolist()
                        z_data = z.data
                    else:
                        domain_label += labels.tolist()
                        z_data = torch.cat((z_data, z.data), 0)

                # loss calculation
                recon_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step, totalIterations, split)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss) / batch_size
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping: store per-example values so the epoch means
                # are not rescaled by the (possibly smaller) last batch size.
                tracker['negELBO'] = torch.cat((tracker['negELBO'], loss.data))
                tracker['KL_loss'] = torch.cat(
                    (tracker['KL_loss'], KL_loss.data / batch_size))
                tracker['Recon_loss'] = torch.cat(
                    (tracker['Recon_loss'], recon_loss.data / batch_size))
                tracker['Perplexity'] = torch.cat(
                    (tracker['Perplexity'],
                     torch.exp(recon_loss.data / batch_size)))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss.data[0],
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.data[0] / batch_size,
                                      epoch * len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight,
                                      epoch * len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.data[0], recon_loss.data[0] / batch_size,
                           KL_loss.data[0] / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            # Publish this test pass for plotting (only if it had batches).
            if split == 'test' and z_bool:
                Z = z_data
                L = domain_label

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs,
                         torch.mean(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  torch.mean(tracker['negELBO']), epoch)
                writer.add_scalar("%s-Epoch/KL_loss" % split.upper(),
                                  torch.mean(tracker['KL_loss']), epoch)
                writer.add_scalar("%s-Epoch/Recon_loss" % split.upper(),
                                  torch.mean(tracker['Recon_loss']), epoch)
                writer.add_scalar("%s-Epoch/Perplexity" % split.upper(),
                                  torch.mean(tracker['Perplexity']), epoch)

            # early stopping + dump of all sentences and the latent space
            if split == 'valid':
                if torch.mean(tracker['negELBO']) < curBest:
                    curBest = torch.mean(tracker['negELBO'])
                else:
                    stop = True
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps_32_0', ts)):
                    os.makedirs('dumps_32_0/' + ts)
                with open(
                        os.path.join('dumps_32_0/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

    if args.tensorboard_logging:
        writer.close()

    # Without a test pass there is nothing to embed; previously this crashed
    # on `[].data`.
    if not L:
        logger.info("No test-set latent codes collected; skipping t-SNE plot.")
        return

    Z = Z.cpu().numpy()
    print(Z.shape)
    beforeTSNE = TSNE(random_state=20150101).fit_transform(Z)
    scatter(beforeTSNE, L, [0, 1, 2], (5, 5), 'latent discoveries')
    plt.savefig('mixed_tsne' + args.anneal_function + '.png', dpi=120)
def main(args):
    """Train a SentenceVAE on PTB (``args.dataset == 'ptb'``) or politician
    tweets (``'twitter'``) with optional sigmoid KL annealing.

    With ``anneal_function == 'sigmoid_klt'``, a batch whose per-example KL is
    below ``args.kl_threshold`` has its KL term replaced by a constant at the
    threshold. Validation sentences and latent codes are dumped to
    ``dumps/<ts>/valid_E<epoch>.json``, a ``state_dict`` checkpoint is saved
    after each training pass, and the whole model is pickled to
    ``model-<dataset>-<ts>.pickle`` at the end.

    Args:
        args: parsed command-line namespace (dataset/model hyper-parameters,
            ``epochs``, ``batch_size``, ``learning_rate``, ``anneal_function``,
            ``anneal_aggression``, ``kl_threshold``, ``print_every``, ``test``,
            ``tensorboard_logging``, ``logdir``, ``save_model_path``).

    Raises:
        SystemExit: if ``args.dataset`` is neither ``'ptb'`` nor ``'twitter'``.
        ValueError: if ``args.anneal_function`` is not ``'identity'``,
            ``'sigmoid'`` or ``'sigmoid_klt'``.
    """
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    if args.dataset == 'ptb':
        Dataset = PTB
    elif args.dataset == 'twitter':
        Dataset = PoliticianTweets
    else:
        # exit() reports success (status 0) and relies on the site module;
        # SystemExit with a message prints it to stderr and exits with status 1.
        raise SystemExit("Invalid dataset. Exiting")

    # Fail fast instead of crashing on `None * KL_loss` at the first batch.
    sigmoid_schedules = ('sigmoid', 'sigmoid_klt')
    if args.anneal_function != 'identity' and args.anneal_function not in sigmoid_schedules:
        raise ValueError("Unsupported anneal_function %r; expected 'identity', "
                         "'sigmoid' or 'sigmoid_klt'" % (args.anneal_function,))

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = Dataset(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ
        )

    model = SentenceVAE(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    # Points on which the sigmoid schedule is evaluated, one per training
    # step. The lengths are hard-coded per dataset (original notes: 13160 =
    # training examples in PTB; 25190 for the short twitter set, possibly
    # 6411 -- TODO confirm they match the intended ramp length).
    linspace = None
    if args.anneal_function in sigmoid_schedules:
        num_points = 13160 if args.dataset == 'ptb' else 25190
        linspace = np.linspace(-5, 5, num_points)

    def kl_anneal_function(anneal_function, step, param_dict=None):
        """Return the KL weight at ``step``; ``param_dict['ag']`` sets the sigmoid slope."""
        if anneal_function == 'identity':
            return 1
        if anneal_function in sigmoid_schedules:
            # Clamp so that steps past the end of the ramp keep the final
            # weight instead of raising IndexError.
            x = linspace[min(step, len(linspace) - 1)]
            return float(1 / (1 + np.exp(-param_dict['ag'] * x)))
        raise ValueError("Unsupported anneal_function %r" % (anneal_function,))

    # Summed (not averaged) token NLL; padding positions are ignored.
    NLL = torch.nn.NLLLoss(size_average=False, ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, param_dict=None):
        """Return ``(NLL_loss, KL_loss, KL_weight)`` summed over one batch.

        ``param_dict`` supplies ``'kl_threshold'`` and, optionally, the actual
        ``'batch_size'`` (defaults to ``args.batch_size``) for the
        ``sigmoid_klt`` schedule.
        """
        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)

        # KL divergence between the diagonal Gaussian posterior and N(0, I)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        if anneal_function == 'sigmoid_klt':
            # Use the real batch size so a smaller final batch is thresholded
            # per example like every other batch.
            batch_size = param_dict.get('batch_size', args.batch_size)
            if float(KL_loss) / batch_size < param_dict['kl_threshold']:
                # Constant replacement: no KL gradient flows for this batch.
                KL_loss = to_var(torch.Tensor([param_dict['kl_threshold'] * batch_size]))
        KL_weight = kl_anneal_function(anneal_function, step, {'ag': args.anneal_aggression})

        return NLL_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split == 'train',
                num_workers=0,
                pin_memory=torch.cuda.is_available()
            )

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)
                if split == 'train' and batch_size != args.batch_size:
                    print("WARNING: Found different batch size\nargs.batch_size= %s, input_size=%s" % (args.batch_size, batch_size))

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                logp, mean, logv, z = model(batch['input'], batch['length'])

                # loss calculation
                NLL_loss, KL_loss, KL_weight = loss_fn(
                    logp, batch['target'], batch['length'], mean, logv,
                    args.anneal_function, step,
                    {'kl_threshold': args.kl_threshold, 'batch_size': batch_size})

                loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                # backward + optimization
                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                tracker['ELBO'] = torch.cat((tracker['ELBO'], loss.data))

                if args.tensorboard_logging:
                    global_step = epoch * len(data_loader) + iteration
                    writer.add_scalar("%s/ELBO" % split.upper(), loss.data[0], global_step)
                    writer.add_scalar("%s/NLL_Loss" % split.upper(), NLL_loss.data[0] / batch_size, global_step)
                    writer.add_scalar("%s/KL_Loss" % split.upper(), KL_loss.data[0] / batch_size, global_step)
                    writer.add_scalar("%s/KL_Weight" % split.upper(), KL_weight, global_step)

                if iteration % args.print_every == 0 or iteration + 1 == len(data_loader):
                    logger.info("%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                                % (split.upper(), iteration, len(data_loader) - 1, loss.data[0], NLL_loss.data[0] / batch_size, KL_loss.data[0] / batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            logger.info("%s Epoch %02d/%i, Mean ELBO %9.4f" % (split.upper(), epoch, args.epochs, torch.mean(tracker['ELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(), torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {'target_sents': tracker['target_sents'], 'z': tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(os.path.join('dumps/' + ts + '/valid_E%i.json' % epoch), 'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path, "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

    if args.tensorboard_logging:
        writer.close()

    torch.save(model, f"model-{args.dataset}-{ts}.pickle")
示例#4
0
def main(args):
    """Train a SentenceVAE on the Gigaword corpus with logistic/linear KL annealing.

    The model hyper-parameters are written to
    ``<args.save_model_path>/<ts>/model_params.json``. Each epoch runs
    ``train``, ``valid`` and optionally ``test``. Validation sentences and
    latent codes are dumped to ``dumps/<ts>/valid_E<epoch>.json``, and a
    ``state_dict`` checkpoint is saved after each training pass.

    Args:
        args: parsed command-line namespace (dataset/model hyper-parameters,
            ``epochs``, ``batch_size``, ``learning_rate``, ``anneal_function``
            (``'logistic'``, ``'linear'`` or ``'identity'``), ``k``, ``x0``,
            ``print_every``, ``test``, ``tensorboard_logging``, ``logdir``,
            ``save_model_path``).

    Raises:
        ValueError: if ``args.anneal_function`` is not a supported schedule.
    """
    # Fail fast: an unknown schedule previously surfaced as a TypeError
    # (None * tensor) on the first batch.
    if args.anneal_function not in ('logistic', 'linear', 'identity'):
        raise ValueError("Unsupported anneal_function %r; expected 'logistic', "
                         "'linear' or 'identity'" % (args.anneal_function,))

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())
    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = Gigaword(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ)

    params = dict(vocab_size=datasets['train'].vocab_size,
                  sos_idx=datasets['train'].sos_idx,
                  eos_idx=datasets['train'].eos_idx,
                  pad_idx=datasets['train'].pad_idx,
                  unk_idx=datasets['train'].unk_idx,
                  max_sequence_length=args.max_sequence_length,
                  embedding_size=args.embedding_size,
                  rnn_type=args.rnn_type,
                  hidden_size=args.hidden_size,
                  word_dropout=args.word_dropout,
                  embedding_dropout=args.embedding_dropout,
                  latent_size=args.latent_size,
                  num_layers=args.num_layers,
                  bidirectional=args.bidirectional)
    model = SentenceVAE(**params)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    # Persist constructor arguments so the checkpoints can be re-instantiated.
    with open(os.path.join(save_model_path, 'model_params.json'), 'w') as f:
        json.dump(params, f, indent=4)

    def kl_anneal_function(anneal_function, step, k, x0):
        """Return the KL weight at ``step``.

        ``logistic``: sigmoid with slope ``k`` centred at ``x0``;
        ``linear``: ``step / x0`` capped at 1; ``identity``: constant 1.
        """
        if anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-k * (step - x0))))
        if anneal_function == 'linear':
            return min(1, step / x0)
        if anneal_function == 'identity':
            return 1
        raise ValueError("Unsupported anneal_function %r" % (anneal_function,))

    # Summed (not averaged) token NLL; padding positions are ignored.
    NLL = torch.nn.NLLLoss(ignore_index=datasets['train'].pad_idx,
                           reduction='sum')

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, k,
                x0):
        """Return ``(NLL_loss, KL_loss, KL_weight)`` summed over one batch."""
        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))
        # Negative Log Likelihood
        NLL_loss = NLL(logp, target)
        # KL divergence between the diagonal Gaussian posterior and N(0, I)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, k, x0)
        return NLL_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    tensor = torch.cuda.FloatTensor if torch.cuda.is_available(
    ) else torch.Tensor
    step = 0
    for epoch in range(args.epochs):
        for split in splits:
            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())
            tracker = defaultdict(tensor)
            # Enable/Disable Dropout
            if split == 'train':
                model.train()
            else:
                model.eval()

            # Only the training pass needs autograd; evaluating without it
            # avoids building (and holding) graphs for validation/test.
            with torch.set_grad_enabled(split == 'train'):
                for iteration, batch in enumerate(data_loader):
                    batch_size = batch['input'].size(0)
                    for key, value in batch.items():
                        if torch.is_tensor(value):
                            batch[key] = to_var(value)
                    # Forward pass
                    logp, mean, logv, z = model(batch['input'],
                                                batch['length'])
                    # loss calculation
                    NLL_loss, KL_loss, KL_weight = loss_fn(
                        logp, batch['target'], batch['length'], mean, logv,
                        args.anneal_function, step, args.k, args.x0)
                    loss = (NLL_loss + KL_weight * KL_loss) / batch_size

                    # backward + optimization
                    if split == 'train':
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        step += 1

                    # bookkeeping
                    tracker['ELBO'] = torch.cat(
                        (tracker['ELBO'], loss.data.view(1, -1)), dim=0)

                    if args.tensorboard_logging:
                        global_step = epoch * len(data_loader) + iteration
                        writer.add_scalar("%s/ELBO" % split.upper(),
                                          loss.item(), global_step)
                        writer.add_scalar("%s/NLL Loss" % split.upper(),
                                          NLL_loss.item() / batch_size,
                                          global_step)
                        writer.add_scalar("%s/KL Loss" % split.upper(),
                                          KL_loss.item() / batch_size,
                                          global_step)
                        writer.add_scalar("%s/KL Weight" % split.upper(),
                                          KL_weight, global_step)

                    if iteration % args.print_every == 0 or iteration + 1 == len(
                            data_loader):
                        print(
                            "%s Batch %04d/%i, Loss %9.4f, NLL-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                            % (split.upper(), iteration, len(data_loader) - 1,
                               loss.item(), NLL_loss.item() / batch_size,
                               KL_loss.item() / batch_size, KL_weight))

                    if split == 'valid':
                        if 'target_sents' not in tracker:
                            tracker['target_sents'] = list()
                        tracker['target_sents'] += idx2word(
                            batch['target'].data,
                            i2w=datasets['train'].get_i2w(),
                            pad_idx=datasets['train'].pad_idx)
                        tracker['z'] = torch.cat((tracker['z'], z.data),
                                                 dim=0)

            print("%s Epoch %02d/%i, Mean ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs, tracker['ELBO'].mean()))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/ELBO" % split.upper(),
                                  torch.mean(tracker['ELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z'].tolist()
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % epoch)
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)

    # Flush pending events and release the event file.
    if args.tensorboard_logging:
        writer.close()
示例#5
0
文件: train.py 项目: timbmg/DIAL-LV
def main(args):
    """Train a DIAL-LV model on OpenSubtitles or GuessWhat dialogue data.

    Each epoch iterates over the 'train' split (with parameter updates) and
    then the 'valid' split (model in eval mode). Per-batch and per-epoch losses
    are printed and, if ``args.tensorboard_logging`` is set, written to
    TensorBoard. Every ``args.print_every`` iterations (and on the last
    iteration of a split) a few sample generations are written with
    ``save_dial_to_json``; after each training epoch the model ``state_dict``
    is saved to ``<args.save_model_path>/<ts>/E<epoch>.pytorch``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options (dataset, model sizes, KL-annealing
        parameters, logging and checkpoint locations, ...).

    Raises
    ------
    ValueError
        If ``args.dataset`` is neither 'opensubtitles' nor 'guesswhat'.
    FileNotFoundError
        If ``args.load_checkpoint`` is non-empty but the file does not exist.
    """

    # Factory for empty float tensors; used to accumulate per-batch statistics.
    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    splits = ['train', 'valid']

    datasets = OrderedDict()
    for split in splits:
        if args.dataset.lower() == 'opensubtitles':
            datasets[split] = OpenSubtitlesQADataset(
                root='data',
                split=split,
                min_occ=args.min_occ,
                max_prompt_length=args.max_input_length,
                max_reply_length=args.max_reply_length
                )
        elif args.dataset.lower() == 'guesswhat':
            datasets[split] = GuessWhatDataset(
                root='data',
                split=split,
                min_occ=args.min_occ,
                max_dialogue_length=args.max_input_length,
                max_question_length=args.max_reply_length
                )
        else:
            # Fail early with a clear message instead of a KeyError on
            # datasets['train'] further down.
            raise ValueError("Unknown dataset '%s'; expected 'opensubtitles' or 'guesswhat'."
                             % args.dataset)

    model = DialLV(vocab_size=datasets['train'].vocab_size,
                    embedding_size=args.embedding_size,
                    hidden_size=args.hidden_size,
                    latent_size=args.latent_size,
                    word_dropout=args.word_dropout,
                    pad_idx=datasets['train'].pad_idx,
                    sos_idx=datasets['train'].sos_idx,
                    eos_idx=datasets['train'].eos_idx,
                    max_utterance_length=args.max_reply_length,
                    bidirectional=args.bidirectional_encoder
                    )

    if args.load_checkpoint != '':
        if not os.path.exists(args.load_checkpoint):
            raise FileNotFoundError(args.load_checkpoint)

        model.load_state_dict(torch.load(args.load_checkpoint))
        print("Model loaded from %s"%(args.load_checkpoint))

    if torch.cuda.is_available():
        model = model.cuda()
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    # NLL summed (not averaged) over all target tokens of a batch.
    # NOTE: `size_average` is deprecated in newer PyTorch in favour of
    # reduction='sum'; it is kept because this script uses the older
    # Variable/to_var-era API.
    NLL = torch.nn.NLLLoss(size_average=False)

    def kl_anneal_function(**kwargs):
        """Return the weight used to scale the KL divergence term.

        Keyword Arguments
        -----------------
        kl_anneal : str
            'logistic' -> 1 / (1 + exp(-k * (global_step - x0)))
            'step'     -> epoch / denom
            anything else disables annealing and returns 1.
        k, x0, global_step :
            Required when kl_anneal == 'logistic'.
        epoch, denom :
            Required when kl_anneal == 'step'.
        """

        if kwargs['kl_anneal'] == 'logistic':
            # https://en.wikipedia.org/wiki/Logistic_function
            assert ('k' in kwargs and 'x0' in kwargs and 'global_step' in kwargs)
            return float(1/(1+np.exp(-kwargs['k']*(kwargs['global_step']-kwargs['x0']))))

        elif kwargs['kl_anneal'] == 'step':
            assert ('epoch' in kwargs and 'denom' in kwargs)
            return kwargs['epoch'] / kwargs['denom']

        else:
            # Disable KL Annealing
            return 1

    def loss_fn(predictions, targets, mean, log_var, **kl_args):
        """Compute the ELBO terms: negative log likelihood and KL divergence.

        Parameters
        ----------
        predictions : Variable(torch.FloatTensor) [? x vocab_size]
            Log probabilities of each generated token in the batch. The number
            of rows depends on the number of tokens in the batch.
        targets : Variable(torch.LongTensor) [?]
            Target token ids, aligned row-by-row with ``predictions``.
        mean : Variable(torch.FloatTensor) [batch_size x latent_size]
            Predicted means of the latent variables.
        log_var : Variable(torch.FloatTensor) [batch_size x latent_size]
            Predicted log variances of the latent variables.
        **kl_args
            Forwarded to ``kl_anneal_function`` (kl_anneal, k, x0,
            global_step, epoch, denom).

        Returns
        -------
        Variable(torch.FloatTensor), Variable(torch.FloatTensor), number, Variable(torch.FloatTensor)
            NLL loss, weighted KL divergence, KL weight and unweighted KL
            divergence. Both losses are summed over the batch.
        """

        nll_loss = NLL(predictions, targets)

        # Closed-form KL(N(mean, exp(log_var)) || N(0, I)), summed over the batch.
        kl_loss = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())

        kl_weight = kl_anneal_function(**kl_args)

        kl_weighted = kl_weight * kl_loss

        return nll_loss, kl_weighted, kl_weight, kl_loss

    def inference(model, train_dataset, split, n=10, m=3):
        """Generate replies for randomly drawn inputs.

        Parameters
        ----------
        model : DIAL-LV
            The DIAL-LV model.
        train_dataset : Dataset
            Dataset to draw random input samples from.
        split : str
            Current split. If 'train', the model is put into eval mode for
            generation and switched back to train mode afterwards.
        n : int
            Number of inputs to sample (capped at the dataset size).
        m : int
            Number of reply generations per input.

        Returns
        -------
        prompts, replies
            ``prompts`` is ``idx2word`` applied to the sampled input
            sequences; ``replies`` is a list with ``m`` entries, each the
            ``idx2word`` output of one ``model.inference`` pass.
        """

        # Sampling is without replacement, so never ask for more than exist.
        n = min(n, len(train_dataset))
        # Draw exactly `n` indices so every row of the buffers below is filled.
        random_input_idx = np.random.choice(np.arange(0, len(train_dataset)), n, replace=False).astype('int64')
        random_inputs = np.zeros((n, args.max_input_length)).astype('int64')
        random_inputs_length = np.zeros(n)
        for i, rqi in enumerate(random_input_idx):
            random_inputs[i] = train_dataset[rqi]['input_sequence']
            random_inputs_length[i] = train_dataset[rqi]['input_length']

        input_sequence = to_var(torch.from_numpy(random_inputs).long())
        input_length = to_var(torch.from_numpy(random_inputs_length).long())
        prompts = idx2word(input_sequence.data, train_dataset.i2w, train_dataset.pad_idx)

        replies = list()
        if split == 'train':
            model.eval()
        for i in range(m):
            replies_ = model.inference(input_sequence, input_length)
            replies.append(idx2word(replies_, train_dataset.i2w, train_dataset.pad_idx))

        if split == 'train':
            model.train()

        return prompts, replies

    ts = time.strftime('%Y-%b-%d|%H:%M:%S', time.gmtime())
    if args.tensorboard_logging:
        # Timestamps have 1-second resolution; retry until the log dir is unused.
        log_path = os.path.join(args.tensorboard_logdir, experiment_name(args, ts))
        while os.path.exists(log_path):
            ts = time.strftime('%Y-%b-%d|%H:%M:%S', time.gmtime())
            log_path = os.path.join(args.tensorboard_logdir, experiment_name(args, ts))

        writer = SummaryWriter(log_path)
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)
        if args.load_checkpoint != '':
            writer.add_text("Loaded From", args.load_checkpoint)
    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    global_step = 0
    for epoch in range(args.epochs):

        for split, dataset in datasets.items():

            data_loader = DataLoader(
                dataset=dataset,
                batch_size=args.batch_size,
                shuffle=split=='train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
                )

            # Each statistic starts as an empty float tensor and grows by one
            # entry per batch.
            tracker = defaultdict(tensor)

            if split == 'train':
                model.train()
            else:
                # disable drop out when in validation
                model.eval()

            t1 = time.time()
            for iteration, batch in enumerate(data_loader):

                # get batch items and wrap them in variables
                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                input_sequence = batch['input_sequence']
                input_length = batch['input_length']
                reply_sequence_in = batch['reply_sequence_in']
                reply_sequence_out = batch['reply_sequence_out']
                reply_length = batch['reply_length']
                batch_size = input_sequence.size(0)

                # model forward pass
                predictions, mean, log_var = model(
                    prompt_sequece=input_sequence,
                    prompt_length=input_length,
                    reply_sequence=reply_sequence_in,
                    reply_length=reply_length
                    )

                # predictions come back packed, so making targets packed as well to ignore all padding tokens
                sorted_length, sort_idx = reply_length.sort(0, descending=True)
                targets = reply_sequence_out[sort_idx]
                targets = pack_padded_sequence(targets, sorted_length.data.tolist(), batch_first=True)[0]

                # compute the loss
                nll_loss, kl_weighted_loss, kl_weight, kl_loss = loss_fn(
                    predictions, targets, mean, log_var, kl_anneal=args.kl_anneal,
                    global_step=global_step, epoch=epoch, k=args.kla_k, x0=args.kla_x0,
                    denom=args.kla_denom
                    )
                loss = nll_loss + kl_weighted_loss

                if split == 'train':
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    global_step += 1

                # bookkeeping (per-sentence values). `.view(-1)` makes each
                # scalar loss a 1-element tensor: torch.cat rejects 0-dim
                # tensors, which is what reductions return on PyTorch >= 0.4.
                tracker['loss']             = torch.cat((tracker['loss'],               loss.data.view(-1)/batch_size))
                tracker['nll_loss']         = torch.cat((tracker['nll_loss'],           nll_loss.data.view(-1)/batch_size))
                tracker['kl_loss']          = torch.cat((tracker['kl_loss'],            kl_loss.data.view(-1)/batch_size))
                tracker['kl_weight']        = torch.cat((tracker['kl_weight'],          tensor([kl_weight])))
                tracker['kl_weighted_loss'] = torch.cat((tracker['kl_weighted_loss'],   kl_weighted_loss.data.view(-1)/batch_size))

                if args.tensorboard_logging:
                    step = epoch * len(data_loader) + iteration
                    writer.add_scalar("%s/Batch-Loss"%(split),              tracker['loss'][-1],                step)
                    writer.add_scalar("%s/Batch-NLL-Loss"%(split),          tracker['nll_loss'][-1],            step)
                    writer.add_scalar("%s/Batch-KL-Loss"%(split),           tracker['kl_loss'][-1],             step)
                    writer.add_scalar("%s/Batch-KL-Weight"%(split),         tracker['kl_weight'][-1],           step)
                    writer.add_scalar("%s/Batch-KL-Loss-Weighted"%(split),  tracker['kl_weighted_loss'][-1],    step)

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    print("%s Batch %04d/%i, Loss %9.4f, NLL Loss %9.4f, KL Loss %9.4f, KLW Loss %9.4f, w %6.4f, tt %6.2f"
                        %(split.upper(), iteration, len(data_loader),
                        tracker['loss'][-1], tracker['nll_loss'][-1], tracker['kl_loss'][-1],
                        tracker['kl_weighted_loss'][-1], tracker['kl_weight'][-1], time.time()-t1))

                    t1 = time.time()

                    # Dump a few sample generations for qualitative inspection.
                    prompts, replies = inference(model, datasets[split], split)
                    save_dial_to_json(prompts, replies, root="dials/"+ts+"/", comment="%s_E%i_I%i"%(split.lower(), epoch, iteration))

            print("%s Epoch %02d/%i, Mean Loss: %.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['loss'])))
            if args.tensorboard_logging:
                writer.add_scalar("%s/Epoch-Loss"%(split),      torch.mean(tracker['loss']),        epoch)
                writer.add_scalar("%s/Epoch-NLL-Loss"%(split),  torch.mean(tracker['nll_loss']),    epoch)
                writer.add_scalar("%s/Epoch-KL-Loss"%(split),   torch.mean(tracker['kl_loss']),     epoch)

            # save checkpoint
            if split == 'train':
                checkpoint_path = os.path.join(save_model_path, "E%i.pytorch"%(epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s"%checkpoint_path)
# Example #6
# 0
def main(args):
    """Train a SentenceVAE on PTB and plot the per-epoch train/valid loss.

    Each epoch runs the 'train' split with parameter updates and the 'valid'
    split in eval mode with gradients disabled. Batch and epoch statistics are
    printed, written to ``<args.logdir>/<experiment_name>.log`` and, if
    ``args.tensorboard_logging`` is set, to TensorBoard. After every
    validation pass the target sentences and latent codes are dumped to
    ``dumps/<ts>/valid_E<epoch>.json``. After every training pass a checkpoint
    is saved under ``<args.save_model_path>/<ts>/``. At the end, a train/valid
    loss curve is saved to ``<args.logdir>/<experiment_name>.png``.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options.

    Raises
    ------
    ValueError
        If ``args.anneal_function`` is not one of the supported schedules.
    """

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid']

    # logging.basicConfig(filename=...) raises if the log directory is missing.
    if args.logdir:
        os.makedirs(args.logdir, exist_ok=True)
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
        filename=os.path.join(args.logdir,
                              experiment_name(args, ts) + ".log"))
    logger = logging.getLogger(__name__)

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    # Estimated number of optimizer steps over the whole run.
    # NOTE(review): 42000 presumably approximates the PTB training-set size;
    # it does not follow the actual dataset — confirm before using other data.
    total_step = int(args.epochs * 42000.0 / args.batch_size)

    # Schedules that do not depend on the training step.
    constant_kl_weights = {'half': 0.5, 'identity': 1, 'double': 2, 'quadra': 4}

    def kl_anneal_function(anneal_function, step):
        """Return the KL weight at `step` for the named schedule.

        Constant schedules: 'half' (0.5), 'identity' (1), 'double' (2),
        'quadra' (4). Step-dependent schedules, defined relative to
        ``total_step``:
          'sigmoid'   - logistic curve equal to 0.5 at total_step / 2,
          'monotonic' - linear ramp from 0 to 1 over the first quarter of
                        training, then held at 1,
          'cyclical'  - four equal cycles; each ramps linearly from 0 to 1
                        over its first quarter and then holds at 1.

        Raises
        ------
        ValueError
            If `anneal_function` is not a known schedule (previously this
            returned None and failed later with an opaque TypeError).
        """
        if anneal_function in constant_kl_weights:
            return constant_kl_weights[anneal_function]

        if anneal_function == 'sigmoid':
            return 1 / (1 + np.exp((0.5 * total_step - step) / 200))

        if anneal_function == 'monotonic':
            return min(step * 4 / total_step, 1.0)

        if anneal_function == 'cyclical':
            t = total_step / 4
            return min(4 * (step % t) / t, 1.0)

        raise ValueError("Unknown anneal_function '%s'" % anneal_function)

    # Token-level NLL summed over the batch; padding positions are ignored.
    ReconLoss = torch.nn.NLLLoss(reduction='sum',
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step):
        """Return (reconstruction loss, KL divergence, KL weight) for a batch.

        Both losses are summed over the batch; the caller divides by the
        batch size.
        """

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence of N(mean, exp(logv)) from N(0, I), in closed form
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step)

        return recon_loss, KL_loss, KL_weight

    def _mean(values):
        """Arithmetic mean of a non-empty list of numbers."""
        return sum(values) / len(values)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    step = 0
    train_loss = []
    test_loss = []
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(list)

            # Enable/Disable Dropout
            is_train = split == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Only build the autograd graph when we will backpropagate.
                with torch.set_grad_enabled(is_train):
                    # Forward pass
                    logp, mean, logv, z = model(batch['input'], batch['length'])

                    # loss calculation
                    recon_loss, KL_loss, KL_weight = loss_fn(
                        logp, batch['target'], batch['length'], mean, logv,
                        args.anneal_function, step)

                    if is_train:
                        loss = (recon_loss + KL_weight * KL_loss) / batch_size
                    else:
                        # report complete elbo when validation
                        loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # per-sentence values, reused for bookkeeping and logging
                loss_value = loss.item()
                recon_value = recon_loss.item() / batch_size
                kl_value = KL_loss.item() / batch_size

                tracker["negELBO"].append(loss_value)
                tracker["recon_loss"].append(recon_value)
                tracker["KL_Loss"].append(kl_value)
                tracker["KL_Weight"].append(KL_weight)

                if args.tensorboard_logging:
                    global_iter = epoch * len(data_loader) + iteration
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss_value, global_iter)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_value, global_iter)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      kl_value, global_iter)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight, global_iter)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "\tStep\t%s\t%04d\t%i\t%9.4f\t%9.4f\t%9.4f\t%6.3f" %
                        (split.upper(), iteration, len(data_loader) - 1,
                         loss_value, recon_value, kl_value, KL_weight))
                    print(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss_value, recon_value, kl_value, KL_weight))

                if split == 'valid':
                    # tracker is a defaultdict(list), so += creates the entry.
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'].append(z.data.tolist())

            mean_neg_elbo = _mean(tracker['negELBO'])
            mean_recon = _mean(tracker['recon_loss'])
            mean_kl = _mean(tracker['KL_Loss'])
            mean_kl_weight = _mean(tracker['KL_Weight'])

            logger.info(
                "\tEpoch\t%s\t%02d\t%i\t%9.4f\t%9.4f\t%9.4f\t%6.3f" %
                (split.upper(), epoch, args.epochs,
                 mean_neg_elbo, mean_recon, mean_kl, mean_kl_weight))
            print("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                  (split.upper(), epoch, args.epochs, mean_neg_elbo))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  mean_neg_elbo, epoch)
                writer.add_scalar("%s-Epoch/recon_loss" % split.upper(),
                                  mean_recon, epoch)
                writer.add_scalar("%s-Epoch/KL_Loss" % split.upper(),
                                  mean_kl, epoch)
                writer.add_scalar("%s-Epoch/KL_Weight" % split.upper(),
                                  mean_kl_weight, epoch)

            # The 'valid' curve is plotted (and labelled) as "Test" below.
            if is_train:
                train_loss.append(mean_neg_elbo)
            else:
                test_loss.append(mean_neg_elbo)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z']
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if is_train:
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                print("Model saved at %s" % checkpoint_path)

    sns.set(style="whitegrid")
    df = pd.DataFrame()
    df["train"] = train_loss
    df["test"] = test_loss
    ax = sns.lineplot(data=df, legend=False)
    ax.set(xlabel='Epoch', ylabel='Loss')
    plt.legend(title='Split', loc='upper right', labels=['Train', 'Test'])
    plt.savefig(os.path.join(args.logdir,
                             experiment_name(args, ts) + ".png"),
                transparent=True,
                dpi=300)
# Example #7
# 0
# File: train.py  Project: pvijayak/RNNVaE
def main(args):
    """Train a SentenceVAE on PTB with configurable KL annealing and early stopping.

    Each epoch runs every split in ``splits``. 'train' is run with parameter
    updates; the other splits run in eval mode with gradients disabled.
    Statistics go to ``logger`` and optionally TensorBoard. After each 'valid'
    pass the target sentences and latent codes are dumped to
    ``dumps/<ts>/valid_E<epoch>.json``. If ``args.early_stopping`` is set, the
    mean validation loss is fed to ``EarlyStopping``, and training stops after
    the epoch in which it fires. A checkpoint is saved after each training
    pass.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. ``args.x1`` / ``args.x2`` parametrise
        the KL annealing schedule.

    Raises
    ------
    ValueError
        If ``args.anneal_function`` is not one of the supported schedules.
    """

    # NOTE(review): `logger` is used below but never defined in this
    # function, so it must be a module-level name (otherwise the first batch
    # log raises NameError) — confirm.

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, x1, x2):
        """Return the KL weight at training `step`.

        'identity'   - constant 1.
        'linear'     - min(1, step / x1): linear ramp reaching 1 at step x1.
        'logistic'   - 1 / (1 + exp(-x2 * (step - x1))): midpoint x1, slope x2.
        'cyclic_log' - the logistic schedule restarted every 3 * x1 steps.
        'cyclic_lin' - the linear schedule restarted every 3 * x1 steps.

        Raises
        ------
        ValueError
            If `anneal_function` is not a known schedule (previously this
            returned None and failed later with an opaque TypeError).
        """
        if anneal_function == 'identity':
            return 1
        elif anneal_function == 'linear':
            return min(1, step / x1)
        elif anneal_function == 'logistic':
            return float(1 / (1 + np.exp(-x2 * (step - x1))))
        elif anneal_function == 'cyclic_log':
            return float(1 / (1 + np.exp(-x2 * ((step % (3 * x1)) - x1))))
        elif anneal_function == 'cyclic_lin':
            return min(1, (step % (3 * x1)) / x1)
        raise ValueError("Unknown anneal_function '%s'" % anneal_function)

    # Token-level NLL summed over the batch; padding positions are ignored.
    # reduction='sum' replaces the deprecated size_average=False.
    ReconLoss = torch.nn.NLLLoss(reduction='sum',
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step, x1,
                x2):
        """Return (reconstruction loss, KL divergence, KL weight) for a batch.

        Both losses are summed over the batch; the caller divides by the
        batch size.
        """

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence of N(mean, exp(logv)) from N(0, I), in closed form
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step, x1, x2)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    step = 0

    early_stopping = EarlyStopping(history=10)
    for epoch in range(args.epochs):

        early_stopping_flag = False
        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            tracker = defaultdict(list)

            # Enable/Disable Dropout
            is_train = split == 'train'
            if is_train:
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Only build the autograd graph when we will backpropagate.
                with torch.set_grad_enabled(is_train):
                    # Forward pass
                    logp, mean, logv, z = model(batch['input'], batch['length'])

                    # loss calculation
                    recon_loss, KL_loss, KL_weight = loss_fn(
                        logp, batch['target'], batch['length'], mean, logv,
                        args.anneal_function, step, args.x1, args.x2)

                    if is_train:
                        loss = (recon_loss + KL_weight * KL_loss) / batch_size
                    else:
                        # report complete elbo when validation
                        loss = (recon_loss + KL_loss) / batch_size

                # backward + optimization
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # per-sentence values, reused for bookkeeping and logging
                loss_value = loss.item()
                recon_value = recon_loss.item() / batch_size
                kl_value = KL_loss.item() / batch_size

                tracker['negELBO'].append(loss_value)

                if args.tensorboard_logging:
                    global_iter = epoch * len(data_loader) + iteration
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      loss_value, global_iter)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_value, global_iter)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      kl_value, global_iter)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight, global_iter)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss_value, recon_value, kl_value, KL_weight))

                if split == 'valid':
                    # tracker is a defaultdict(list), so += creates the entry.
                    tracker['target_sents'] += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    tracker['z'].append(z.data.tolist())

            mean_loss = sum(tracker['negELBO']) / len(tracker['negELBO'])

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs, mean_loss))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  mean_loss, epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': tracker['target_sents'],
                    'z': tracker['z']
                }
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/' + ts)
                with open(
                        os.path.join('dumps/' + ts +
                                     '/valid_E%i.json' % epoch),
                        'w') as dump_file:
                    json.dump(dump, dump_file)
                # Early stopping is decided on the validation loss only.
                if args.early_stopping and early_stopping.check(mean_loss):
                    early_stopping_flag = True

            # save checkpoint
            if is_train:
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

        # Finish all splits of the current epoch before stopping.
        if early_stopping_flag:
            print("Early stopping trigerred. Training stopped...")
            break
def main(args):
    """Train a SentenceVAE on PTB, then optionally decode samples.

    Every epoch runs the 'train' split (with optimisation) followed by
    'valid' (and 'test' when ``args.test`` is set) in evaluation mode.
    After each training pass the model ``state_dict`` is saved under
    ``<args.save_model_path>/<timestamp>/E<epoch>.pytorch``. After each
    validation pass the target sentences and their latent codes ``z`` are
    dumped to ``dumps/<timestamp>/valid_E<epoch>.json``. When
    ``args.num_samples`` is truthy, that many sentences are generated with
    ``model.inference`` after training and printed.

    Args:
        args: parsed command-line namespace. Fields read here: data_dir,
            create_data, max_sequence_length, min_occ, test, embedding_size,
            rnn_type, hidden_size, word_dropout, embedding_dropout,
            latent_size, num_layers, bidirectional, tensorboard_logging,
            logdir, save_model_path, batch_size, epochs, warmup,
            learning_rate, anneal_function, print_every, num_samples.

    Raises:
        ValueError: if ``args.anneal_function`` is not 'identity' or 'linear'.
    """
    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(data_dir=args.data_dir,
                              split=split,
                              create_data=args.create_data,
                              max_sequence_length=args.max_sequence_length,
                              min_occ=args.min_occ)

    model = SentenceVAE(vocab_size=datasets['train'].vocab_size,
                        sos_idx=datasets['train'].sos_idx,
                        eos_idx=datasets['train'].eos_idx,
                        pad_idx=datasets['train'].pad_idx,
                        unk_idx=datasets['train'].unk_idx,
                        max_sequence_length=args.max_sequence_length,
                        embedding_size=args.embedding_size,
                        rnn_type=args.rnn_type,
                        hidden_size=args.hidden_size,
                        word_dropout=args.word_dropout,
                        embedding_dropout=args.embedding_dropout,
                        latent_size=args.latent_size,
                        num_layers=args.num_layers,
                        bidirectional=args.bidirectional)

    if torch.cuda.is_available():
        model = model.cuda()

    print(model)

    if args.tensorboard_logging:
        writer = SummaryWriter(
            os.path.join(args.logdir, experiment_name(args, ts)))
        writer.add_text("model", str(model))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    # The training DataLoader keeps the final partial batch (drop_last is
    # not set), so one epoch performs ceil(N / batch_size) optimizer steps.
    # Counting with floor division let the linear KL weight exceed 1.0.
    steps_per_epoch = -(-len(datasets['train']) // args.batch_size)
    total_steps = steps_per_epoch * args.epochs
    print("Total training steps", total_steps)

    def kl_anneal_function(anneal_function, step):
        """Return the KL-term weight to use at optimizer step ``step``.

        'identity' always yields 1. 'linear' ramps from 0 towards 1 over all
        training steps when ``args.warmup`` is None. Otherwise it ramps over
        ``args.warmup`` epochs and stays at 1.0 afterwards.

        Raises:
            ValueError: for any other ``anneal_function`` value.
        """
        if anneal_function == 'identity':
            return 1
        if anneal_function == 'linear':
            if args.warmup is None:
                return min(1.0, step / total_steps)
            warmup_steps = steps_per_epoch * args.warmup
            if step < warmup_steps:
                return 1 - (warmup_steps - step) / warmup_steps
            return 1.0
        raise ValueError("Unknown anneal_function: %r" % (anneal_function,))

    # Summed (not averaged) token NLL; padding positions are ignored.
    ReconLoss = torch.nn.NLLLoss(reduction='sum',
                                 ignore_index=datasets['train'].pad_idx)

    def loss_fn(logp, target, length, mean, logv, anneal_function, step):
        """Compute the batch reconstruction loss, KL term and KL weight.

        Returns:
            (recon_loss, KL_loss, KL_weight): the first two are scalar
            tensors summed over the whole batch (not divided by batch size).
            KL_weight is the number returned by ``kl_anneal_function``.
        """
        # Cut off unnecessary padding from target, and flatten. ``.item()``
        # is required: indexing a 0-dim tensor with [0] raises IndexError.
        target = target[:, :torch.max(length).item()].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))

        # Negative log likelihood
        recon_loss = ReconLoss(logp, target)

        # Closed-form KL between N(mean, exp(logv)) and N(0, I)
        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        KL_weight = kl_anneal_function(anneal_function, step)

        return recon_loss, KL_loss, KL_weight

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(dataset=datasets[split],
                                     batch_size=args.batch_size,
                                     shuffle=split == 'train',
                                     num_workers=cpu_count(),
                                     pin_memory=torch.cuda.is_available())

            is_train = split == 'train'
            # Per-split accumulators. Lists are appended to and joined once
            # at the end, instead of re-concatenating a tensor every batch.
            neg_elbos = []      # detached per-batch negative ELBO scalars
            target_sents = []   # decoded target sentences ('valid' only)
            latent_codes = []   # per-batch latent codes z ('valid' only)

            # Enable/Disable Dropout
            if is_train:
                model.train()
            else:
                model.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Only record the autograd graph when we will back-propagate.
                with torch.set_grad_enabled(is_train):
                    # Forward pass
                    logp, mean, logv, z = model(batch['input'],
                                                batch['length'])

                    # loss calculation
                    recon_loss, KL_loss, KL_weight = loss_fn(
                        logp, batch['target'], batch['length'], mean, logv,
                        args.anneal_function, step)

                    # Un-annealed negative ELBO per sentence. This is what
                    # the "Negative ELBO" metrics report for every split.
                    neg_elbo = (recon_loss + KL_loss) / batch_size
                    if is_train:
                        loss = (recon_loss + KL_weight * KL_loss) / batch_size
                    else:
                        # report complete elbo when validation
                        loss = neg_elbo

                # backward + optimization
                if is_train:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    step += 1

                # bookkeeping
                neg_elbos.append(neg_elbo.detach())

                if args.tensorboard_logging:
                    global_iter = epoch * len(data_loader) + iteration
                    writer.add_scalar("%s/Negative_ELBO" % split.upper(),
                                      neg_elbo.item(), global_iter)
                    writer.add_scalar("%s/Recon_Loss" % split.upper(),
                                      recon_loss.item() / batch_size,
                                      global_iter)
                    writer.add_scalar("%s/KL_Loss" % split.upper(),
                                      KL_loss.item() / batch_size,
                                      global_iter)
                    writer.add_scalar("%s/KL_Weight" % split.upper(),
                                      KL_weight, global_iter)

                if iteration % args.print_every == 0 or iteration + 1 == len(
                        data_loader):
                    logger.info(
                        "%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        % (split.upper(), iteration, len(data_loader) - 1,
                           loss.item(), recon_loss.item() / batch_size,
                           KL_loss.item() / batch_size, KL_weight))

                if split == 'valid':
                    target_sents += idx2word(
                        batch['target'].data,
                        i2w=datasets['train'].get_i2w(),
                        pad_idx=datasets['train'].pad_idx)
                    latent_codes.append(z.detach())

            # Mean over batches; NaN when the loader yielded no batches.
            if neg_elbos:
                mean_neg_elbo = torch.stack(neg_elbos).mean().item()
            else:
                mean_neg_elbo = float('nan')

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f" %
                        (split.upper(), epoch, args.epochs, mean_neg_elbo))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO" % split.upper(),
                                  mean_neg_elbo, epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {
                    'target_sents': target_sents,
                    'z': (torch.cat(latent_codes, dim=0).tolist()
                          if latent_codes else [])
                }
                dump_dir = os.path.join('dumps', ts)
                os.makedirs(dump_dir, exist_ok=True)
                with open(os.path.join(dump_dir, 'valid_E%i.json' % epoch),
                          'w') as dump_file:
                    json.dump(dump, dump_file)

            # save checkpoint
            if is_train:
                checkpoint_path = os.path.join(save_model_path,
                                               "E%i.pytorch" % (epoch))
                torch.save(model.state_dict(), checkpoint_path)
                logger.info("Model saved at %s" % checkpoint_path)

    if args.tensorboard_logging:
        # Flush pending events to disk.
        writer.close()

    if args.num_samples:
        torch.cuda.empty_cache()
        model.eval()
        with torch.no_grad():
            print(f"Generating {args.num_samples} samples")
            generations, _ = model.inference(n=args.num_samples)
            # Index-to-word mapping, looked up with stringified token ids.
            vocab = datasets["train"].i2w

            print(
                "Sampled latent codes from z ~ N(0, I), generated sentences:")
            for i, generation in enumerate(generations, start=1):
                sentence = [vocab[str(word.item())] for word in generation]
                print(f"{i}:", " ".join(sentence))
def main(args):

    ts = time.strftime('%Y-%b-%d-%H:%M:%S', time.gmtime())

    splits = ['train', 'valid'] + (['test'] if args.test else [])

    datasets = OrderedDict()
    for split in splits:
        datasets[split] = PTB(
            data_dir=args.data_dir,
            split=split,
            create_data=args.create_data,
            max_sequence_length=args.max_sequence_length,
            min_occ=args.min_occ
        )

    encoderVAE = EncoderVAE(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )
    
    decoderVAE = DecoderVAE(
        vocab_size=datasets['train'].vocab_size,
        sos_idx=datasets['train'].sos_idx,
        eos_idx=datasets['train'].eos_idx,
        pad_idx=datasets['train'].pad_idx,
        unk_idx=datasets['train'].unk_idx,
        max_sequence_length=args.max_sequence_length,
        embedding_size=args.embedding_size,
        rnn_type=args.rnn_type,
        hidden_size=args.hidden_size,
        word_dropout=args.word_dropout,
        embedding_dropout=args.embedding_dropout,
        latent_size=args.latent_size,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional
        )
    
    
    

    if torch.cuda.is_available():
        encoderVAE = encoderVAE.cuda()
        decoderVAE = decoderVAE.cuda()

    if args.tensorboard_logging:
        writer = SummaryWriter(os.path.join(args.logdir, experiment_name(args,ts)))
        #writer.add_text("model", str(mode))
        writer.add_text("args", str(args))
        writer.add_text("ts", ts)

    save_model_path = os.path.join(args.save_model_path, ts)
    os.makedirs(save_model_path)

    def kl_anneal_function(anneal_function, step, totalIterations, split):
        if(split != 'train'):
            return 1
        elif anneal_function == 'identity':
            return 1
        elif anneal_function == 'linear':
            return 1.005*float(step)/totalIterations
        elif anneal_function == 'sigmoid':
            return (1/(1 + math.exp(-8*(float(step)/totalIterations))))
        elif anneal_function == 'tanh':
            return math.tanh(4*(float(step)/totalIterations))
        elif anneal_function == 'linear_capped':
            #print(float(step)*30/totalIterations)
            return min(1.0, float(step)*5/totalIterations)
        elif anneal_function == 'cyclic':
            quantile = int(totalIterations/5)
            remainder = int(step % quantile)
            midPoint = int(quantile/2)
            if(remainder > midPoint):
                return 1
            else:
                return float(remainder)/midPoint 
        else:
            return 1

    ReconLoss = torch.nn.NLLLoss(size_average=False, ignore_index=datasets['train'].pad_idx)
    def loss_fn(logp, target, length, mean, logv, anneal_function, step, totalIterations, split):

        # cut-off unnecessary padding from target, and flatten
        target = target[:, :torch.max(length).data[0]].contiguous().view(-1)
        logp = logp.view(-1, logp.size(2))
        
        # Negative Log Likelihood
        recon_loss = ReconLoss(logp, target)

        # KL Divergence
        #print((1 + logv - mean.pow(2) - logv.exp()).size())

        KL_loss = -0.5 * torch.sum(1 + logv - mean.pow(2) - logv.exp())
        #print(KL_loss.size())
        KL_weight = kl_anneal_function(anneal_function, step, totalIterations, split)

        return recon_loss, KL_loss, KL_weight

    encoderOptimizer = torch.optim.Adam(encoderVAE.parameters(), lr=args.learning_rate)
    decoderOptimizer = torch.optim.Adam(decoderVAE.parameters(), lr=args.learning_rate)

    tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.Tensor
    step = 0
    for epoch in range(args.epochs):

        for split in splits:

            data_loader = DataLoader(
                dataset=datasets[split],
                batch_size=args.batch_size,
                shuffle=split=='train',
                num_workers=cpu_count(),
                pin_memory=torch.cuda.is_available()
            )
            
            totalIterations = (int(len(datasets[split])/args.batch_size) + 1)*args.epochs

            tracker = defaultdict(tensor)

            # Enable/Disable Dropout
            if split == 'train':
                encoderVAE.train()
                decoderVAE.train()
            else:
                encoderVAE.eval()
                decoderVAE.eval()

            for iteration, batch in enumerate(data_loader):

                batch_size = batch['input'].size(0)

                for k, v in batch.items():
                    if torch.is_tensor(v):
                        batch[k] = to_var(v)

                # Forward pass
                hidden, mean, logv, z = encoderVAE(batch['input'], batch['length'])

                # loss calculation
                logp = decoderVAE(batch['input'], batch['length'], hidden)
                
                recon_loss, KL_loss, KL_weight = loss_fn(logp, batch['target'],
                    batch['length'], mean, logv, args.anneal_function, step, totalIterations, split)

                if split == 'train':
                    loss = (recon_loss + KL_weight * KL_loss)/batch_size
                    negELBO = loss
                else:
                    # report complete elbo when validation
                    loss = (recon_loss + KL_loss)/batch_size
                    negELBO = loss

                # backward + optimization
                if split == 'train':
                    encoderOptimizer.zero_grad()
                    decoderOptimizer.zero_grad()
                    loss.backward()
                    if(step < 500):
                        encoderOptimizer.step()
                    else:
                        encoderOptimizer.step()
                        decoderOptimizer.step()
                        
                    #optimizer.step()
                    step += 1


                # bookkeepeing
                tracker['negELBO'] = torch.cat((tracker['negELBO'], negELBO.data))

                if args.tensorboard_logging:
                    writer.add_scalar("%s/Negative_ELBO"%split.upper(), negELBO.data[0], epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/Recon_Loss"%split.upper(), recon_loss.data[0]/batch_size, epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Loss"%split.upper(), KL_loss.data[0]/batch_size, epoch*len(data_loader) + iteration)
                    writer.add_scalar("%s/KL_Weight"%split.upper(), KL_weight, epoch*len(data_loader) + iteration)

                if iteration % args.print_every == 0 or iteration+1 == len(data_loader):
                    logger.info("%s Batch %04d/%i, Loss %9.4f, Recon-Loss %9.4f, KL-Loss %9.4f, KL-Weight %6.3f"
                        %(split.upper(), iteration, len(data_loader)-1, negELBO.data[0], recon_loss.data[0]/batch_size, KL_loss.data[0]/batch_size, KL_weight))

                if split == 'valid':
                    if 'target_sents' not in tracker:
                        tracker['target_sents'] = list()
                    tracker['target_sents'] += idx2word(batch['target'].data, i2w=datasets['train'].get_i2w(), pad_idx=datasets['train'].pad_idx)
                    tracker['z'] = torch.cat((tracker['z'], z.data), dim=0)

            logger.info("%s Epoch %02d/%i, Mean Negative ELBO %9.4f"%(split.upper(), epoch, args.epochs, torch.mean(tracker['negELBO'])))

            if args.tensorboard_logging:
                writer.add_scalar("%s-Epoch/NegELBO"%split.upper(), torch.mean(tracker['negELBO']), epoch)

            # save a dump of all sentences and the encoded latent space
            if split == 'valid':
                dump = {'target_sents':tracker['target_sents'], 'z':tracker['z'].tolist()}
                if not os.path.exists(os.path.join('dumps', ts)):
                    os.makedirs('dumps/'+ts)
                with open(os.path.join('dumps/'+ts+'/valid_E%i.json'%epoch), 'w') as dump_file:
                    json.dump(dump,dump_file)