import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

# ShakespeareModern, Discriminator and the *_path constants come from the project's own modules.
def train(model_config, train_config):
	mode = 'train'

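	# Dataset yields a Shakespearean sentence and a modern-English sentence per item, plus additional features for each.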
	dataset = ShakespeareModern(train_shakespeare_path, test_shakespeare_path, train_modern_path, test_modern_path, mode=mode)
	dataloader = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=False)
	print('domain_A_max_len:', dataset.domain_A_max_len)
	shakespeare_disc = Discriminator(model_config['embedding_size'], model_config['hidden_dim'], len(dataset.vocab), batch_size=train_config['batch_size']).cuda()
	shakespeare_disc.train()

	if train_config['continue_train']:
		shakespeare_disc.load_state_dict(torch.load(train_config['model_path']))

	criterion = nn.BCELoss().cuda()
	optimizer = torch.optim.Adam(shakespeare_disc.parameters(), lr=train_config['base_lr'],
								 weight_decay=1e-5)

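	# Target labels: Shakespearean text is treated as "real" (1), modern English as "fake" (0).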
	real_label = torch.ones((train_config['batch_size'], 1)).cuda()
	fake_label = torch.zeros((train_config['batch_size'], 1)).cuda()

	for epoch in range(train_config['num_epochs']):
		for idx, (s, s_addn_feats, m, m_addn_feats) in enumerate(tqdm(dataloader)):
			s = s.transpose(0, 1)  # (batch, seq_len) -> (seq_len, batch)
			m = m.transpose(0, 1)

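			# Shakespearean batch: push the discriminator output towards the "real" label.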
			s = Variable(s).cuda()
			s_output = shakespeare_disc(s, s_addn_feats)
			s_loss = criterion(s_output, real_label)
			s_loss = 100 * s_loss
			optimizer.zero_grad()
			s_loss.backward()
			optimizer.step()
			shakespeare_disc.hidden = shakespeare_disc.init_hidden()  # reset the recurrent hidden state between updates

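			# Modern-English batch: push the discriminator output towards the "fake" label.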
			m = Variable(m).cuda()
			m_output = shakespeare_disc(m, m_addn_feats)
			m_loss = criterion(m_output, fake_label)
			m_loss = 100 * m_loss
			optimizer.zero_grad()
			m_loss.backward()
			optimizer.step()
			shakespeare_disc.hidden = shakespeare_disc.init_hidden()

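			# Periodic logging; calling .item() on the outputs assumes batch_size == 1.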
			if idx % 100 == 0:
				print('\tepoch [{}/{}], iter: {}, s_loss: {:.4f}, m_loss: {:.4f}, preds: s: {}, {}, m: {}, {}'
					.format(epoch+1, train_config['num_epochs'], idx, s_loss.item(), m_loss.item(), s_output.item(), round(s_output.item()), m_output.item(), round(m_output.item())))

		print('\tepoch [{}/{}]'.format(epoch+1, train_config['num_epochs']))

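		# Checkpoint the discriminator after every epoch.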
		torch.save(shakespeare_disc.state_dict(), './shakespeare_disc.pth')
Example 2
cls_criterion = nn.NLLLoss().to(opts.device)
mse_criterion = nn.MSELoss().to(opts.device)
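# Pre-allocate reusable buffers for the noise vector, visual features, attribute vectors and class labels.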
noise = torch.FloatTensor(opts.batch_size, opts.nz).to(opts.device)
input_res = torch.FloatTensor(opts.batch_size, opts.f_dim).to(opts.device)
input_att = torch.FloatTensor(opts.batch_size, opts.atts_dim).to(opts.device)
input_label = torch.LongTensor(opts.batch_size).to(opts.device)


# training and test
seenclasses = data.seenclasses.to(opts.device)
unseenclasses = data.unseenclasses.to(opts.device)
for epoch in range(opts.nepoch):
    netRS.to(opts.device)
    netRU.to(opts.device)
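    # Iterate over the training set in mini-batches.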
    for i in range(0, data.ntrain, opts.batch_size):
        netD.train()
        netG.eval()
        # train Discriminator
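        # Multiple critic updates (opts.critic_iter) per outer iteration.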
        for iter_d in range(opts.critic_iter):
            batch_feat, batch_l, batch_att = data.next_batch(opts.batch_size)
            batch_feat = batch_feat.to(opts.device)
            batch_l = batch_l.to(opts.device)
            batch_att = batch_att.to(opts.device)

            netD.zero_grad()
            
            # real loss
            criticD_real = netD(batch_feat, batch_att)
            criticD_real = -criticD_real.mean()  # maximize the critic's score on real features (WGAN-style critic loss)
            criticD_real.backward()
Example 3
# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor
input_source = Tensor(args.batch_size, args.input_nc, args.size, args.size)
input_target = Tensor(args.batch_size, args.output_nc, args.size, args.size)
target_real = Variable(Tensor(args.batch_size).fill_(1.0), requires_grad=False)
target_fake = Variable(Tensor(args.batch_size).fill_(0.0), requires_grad=False)
fake_img_buffer = ReplayBuffer()  # pool of previously generated images, replayed when updating the discriminator

# perceptual loss models:
vgg = VGGFeature().cuda()

###### Training ######
print('dataloader:', len(dataloader))  # 1334
for epoch in range(start_epoch, args.epochs):
    start_time = time.time()
    netG.train()
    netD.train()
    # define average meters:
    loss_G_meter, loss_G_perceptual_meter, loss_G_GAN_meter, loss_D_meter = \
        AverageMeter(), AverageMeter(), AverageMeter(), AverageMeter()
    for i, batch in enumerate(dataloader):
        # Set model input
        input_img = Variable(input_source.copy_(batch[source_str]))  # X
        teacher_output_img = Variable(input_target.copy_(
            batch[target_str]))  # Gt(X)
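        # teacher_output_img (Gt(X)) serves as the target for the generator being trained.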

        # print('input_img:', input_img.size(), torch.max(input_img), torch.min(input_img))
        # print('teacher_output_img:', teacher_output_img.size(), torch.max(teacher_output_img), torch.min(teacher_output_img))

        ###### G ######
        optimizer_G.zero_grad()
        optimizer_gamma.zero_grad()