def test(model_config):
	"""Evaluate the trained Shakespeare discriminator on the test split.

	Loads weights from './shakespeare_disc.pth', feeds each Shakespeare /
	modern sentence pair through the discriminator, and prints binary
	classification accuracy (Shakespeare -> 1, modern -> 0).

	Args:
		model_config: dict with 'embedding_size' and 'hidden_dim' keys used
			to rebuild the Discriminator with the same architecture it was
			trained with.
	"""
	mode = 'test'
	batch_size = 1
	dataset = ShakespeareModern(train_shakespeare_path, test_shakespeare_path, train_modern_path, test_modern_path, mode=mode)
	dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

	shakespeare_disc = Discriminator(model_config['embedding_size'], model_config['hidden_dim'], len(dataset.vocab)).cuda()
	shakespeare_disc.load_state_dict(torch.load('./shakespeare_disc.pth'))
	shakespeare_disc.eval()

	num_correct = 0
	total_samples = 0

	# Inference only: disable autograd to avoid building graphs and to cut
	# memory use. This also makes the deprecated Variable wrapper unnecessary.
	with torch.no_grad():
		for s, s_addn_feats, m, m_addn_feats in tqdm(dataloader):
			# Batches arrive batch-first; the discriminator expects
			# (seq_len, batch) — presumably an LSTM input layout, matching
			# the same transpose done in train().
			s = s.transpose(0, 1).cuda()
			m = m.transpose(0, 1).cuda()

			# Each iteration scores one Shakespeare and one modern sentence.
			total_samples += 2

			# Shakespeare sentence: correct when the sigmoid output rounds to 1.
			if round(shakespeare_disc(s, s_addn_feats).item()) == 1:
				num_correct += 1

			# Modern sentence: correct when the output rounds to 0.
			if round(shakespeare_disc(m, m_addn_feats).item()) == 0:
				num_correct += 1

	# Guard against an empty test split so we never divide by zero.
	accuracy = num_correct / total_samples if total_samples else 0.0
	print('Accuracy: {}'.format(accuracy))
def train(model_config, train_config):
	"""Train the Shakespeare discriminator with BCE loss.

	Shakespeare sentences are pushed toward label 1 and modern sentences
	toward label 0, with separate backward/step passes per domain. The model
	is checkpointed to './shakespeare_disc.pth' after every epoch.

	Args:
		model_config: dict with 'embedding_size' and 'hidden_dim'.
		train_config: dict with 'batch_size', 'base_lr', 'num_epochs',
			'continue_train' and (when resuming) 'model_path'.
	"""
	mode = 'train'

	dataset = ShakespeareModern(train_shakespeare_path, test_shakespeare_path, train_modern_path, test_modern_path, mode=mode)
	# drop_last=True is required: the real/fake label tensors below are sized
	# for a full batch, so a smaller final batch would crash BCELoss with a
	# shape mismatch whenever len(dataset) % batch_size != 0.
	# NOTE(review): shuffle=False is kept from the original — confirm whether
	# shuffling was deliberately disabled.
	dataloader = DataLoader(dataset, batch_size=train_config['batch_size'], shuffle=False, drop_last=True)
	print(dataset.domain_A_max_len)

	shakespeare_disc = Discriminator(model_config['embedding_size'], model_config['hidden_dim'], len(dataset.vocab), batch_size=train_config['batch_size']).cuda()
	shakespeare_disc.train()

	# Optionally resume from a previous checkpoint.
	if train_config['continue_train']:
		shakespeare_disc.load_state_dict(torch.load(train_config['model_path']))

	criterion = nn.BCELoss().cuda()
	optimizer = torch.optim.Adam(shakespeare_disc.parameters(), lr=train_config['base_lr'],
								 weight_decay=1e-5)

	# Fixed targets: 1 for Shakespeare (real), 0 for modern (fake).
	real_label = torch.ones((train_config['batch_size'], 1)).cuda()
	fake_label = torch.zeros((train_config['batch_size'], 1)).cuda()

	for epoch in range(train_config['num_epochs']):
		for idx, (s, s_addn_feats, m, m_addn_feats) in tqdm(enumerate(dataloader)):
			# Batch-first -> (seq_len, batch), the layout the model expects.
			s = s.transpose(0, 1)
			m = m.transpose(0, 1)

			# --- Shakespeare (real) pass ---
			s = Variable(s).cuda()
			s_output = shakespeare_disc(s, s_addn_feats)
			# The 100x scaling preserves the original training dynamics.
			s_loss = 100 * criterion(s_output, real_label)
			optimizer.zero_grad()
			s_loss.backward()
			optimizer.step()
			# Reset the recurrent hidden state between sequences so batches
			# do not leak state into each other.
			shakespeare_disc.hidden = shakespeare_disc.init_hidden()

			# --- Modern (fake) pass ---
			m = Variable(m).cuda()
			m_output = shakespeare_disc(m, m_addn_feats)
			m_loss = 100 * criterion(m_output, fake_label)
			optimizer.zero_grad()
			m_loss.backward()
			optimizer.step()
			shakespeare_disc.hidden = shakespeare_disc.init_hidden()

			# Periodic progress report with raw and rounded predictions.
			if idx % 100 == 0:
				print('\tepoch [{}/{}], iter: {}, s_loss: {:.4f}, m_loss: {:.4f}, preds: s: {}, {}, m: {}, {}'
					.format(epoch+1, train_config['num_epochs'], idx, s_loss.item(), m_loss.item(), s_output.item(), round(s_output.item()), m_output.item(), round(m_output.item())))

		print('\tepoch [{}/{}]'.format(epoch+1, train_config['num_epochs']))

		# Checkpoint after every epoch so progress survives interruption.
		torch.save(shakespeare_disc.state_dict(), './shakespeare_disc.pth')
Ejemplo n.º 3
0
        netD,
        optimizer_G,
        optimizer_D,
        lr_scheduler_G,
        lr_scheduler_D,
        path=os.path.join(results_dir, 'pth', 'latest.pth'))
    start_epoch = last_epoch + 1
else:
    # load sub G extracted from latest.pth
    g_path = os.path.join(subnet_model_path, 'epoch%d_netG.pth' % 199)
    netG.load_state_dict(torch.load(g_path))
    print('load G from %s' % g_path)
    # load full D directly from latest.pth
    d_path = os.path.join('cp_results', args.dataset, args.task,
                          args.base_model_str, 'pth', 'latest.pth')
    netD.load_state_dict(torch.load('initial_weights/netD_B_seed_1.pth.tar'))
    print('load D from %s' % 'initial_weights')
    start_epoch = 0
    best_FID = 1e9
    loss_G_lst, loss_G_perceptual_lst, loss_G_GAN_lst, loss_D_lst = [], [], [], []

# Dataset loader: img shape=(256,256)
dataset_dir = os.path.join(foreign_dir, 'datasets', args.dataset)
soft_data_dir = os.path.join(foreign_dir, 'train_set_result', args.dataset)
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]  # (0,1) -> (-1,1)
dataloader = DataLoader(PairedImageDataset(dataset_dir,
                                           soft_data_dir,
                                           transforms_=transforms_,
Ejemplo n.º 4
0
        netD,
        optimizer_G,
        optimizer_D,
        lr_scheduler_G,
        lr_scheduler_D,
        path=os.path.join(results_dir, 'pth', 'latest.pth'))
    start_epoch = last_epoch + 1
else:
    # load sub G extracted from latest.pth
    g_path = os.path.join(subnet_model_path, 'epoch%d_netG.pth' % 199)
    netG.load_state_dict(torch.load(g_path))
    print('load G from %s' % g_path)
    # load full D directly from latest.pth
    d_path = os.path.join('results', args.dataset, args.task,
                          args.base_model_str, 'pth', 'latest.pth')
    netD.load_state_dict(torch.load(d_path)['netD'])
    print('load D from %s' % d_path)
    start_epoch = 0
    best_FID = 1e9
    loss_G_lst, loss_G_perceptual_lst, loss_G_GAN_lst, loss_D_lst = [], [], [], []

# Dataset loader: img shape=(256,256)
dataset_dir = os.path.join(foreign_dir, 'datasets', args.dataset)
soft_data_dir = os.path.join(foreign_dir, 'train_set_result', args.dataset)
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]  # (0,1) -> (-1,1)
dataloader = DataLoader(PairedImageDataset(dataset_dir,
                                           soft_data_dir,
                                           transforms_=transforms_,
Ejemplo n.º 5
0
        '''
        for param in model.trans.parameters():
            param.requires_grad = False
        for param in model.atmos.parameters():
            param.requires_grad = False
        '''
    except Exception as e:
        try:
            model.load_state_dict(torch.load(sys.argv[2]))
        except Exception as e:
            print("No weights. Training from scratch.")
    if MODE == 'GAN':
        model_d = Discriminator().to(device)
        optimizer_d = torch.optim.Adam(model_d.parameters(), lr=learning_rate)
        try:
            model_d.load_state_dict(torch.load(sys.argv[3]))
            if opt['parallel']:
                model_d = nn.DataParallel(model_d)
        except Exception as e:
            print("No weights. Training from scratch discrim.")
else:
    print('MODE INCORRECT : TRANS or ATMOS or FAST or DUAL or GAN')
    exit()

# Wrap in Data Parallel for multi-GPU use
if opt['parallel']:
    model = nn.DataParallel(model)

# Set default early stop, if not defined
if not 'early_stop' in opt:
    opt['early_stop'] = 100
Ejemplo n.º 6
0
        lr_scheduler_G,
        lr_scheduler_D,
        lr_scheduler_gamma,
        path=os.path.join(results_dir, 'pth', 'latest.pth'))
    start_epoch = last_epoch + 1
else:
    if args.dataset == 'horse2zebra':
        dense_model_folder = 'pretrained_dense_model_quant' if args.quant else 'pretrained_dense_model'
        g_path = os.path.join(foreign_dir, dense_model_folder, args.dataset,
                              'pth', 'netG_%s_epoch_%d.pth' % (args.task, 199))
        netG.load_state_dict(torch.load(g_path))
        print('load G from %s' % g_path)
        d_path = os.path.join(foreign_dir, dense_model_folder, args.dataset,
                              'pth',
                              'netD_%s_epoch_%d.pth' % (target_str, 199))
        netD.load_state_dict(torch.load(d_path))
        print('load D from %s' % d_path)
    start_epoch = 0
    loss_G_lst, loss_G_perceptual_lst, loss_G_GAN_lst, loss_D_lst, channel_number_lst = [], [], [], [], []

# Dataset loader: image shape=(256,256)
dataset_dir = os.path.join(foreign_dir, 'datasets', args.dataset)
soft_data_dir = os.path.join(foreign_dir, 'train_set_result', args.dataset)
transforms_ = [
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]  # (0,1) -> (-1,1)
dataloader = DataLoader(PairedImageDataset(dataset_dir,
                                           soft_data_dir,
                                           transforms_=transforms_,
                                           mode=args.task),