Ejemplo n.º 1
0
from models.discriminator import Discriminator
from models.generator import Generator


def generate_real_data() -> torch.Tensor:
    """Build one 4-element "real" training sample for the discriminator.

    The pattern alternates high/low values: indices 0 and 2 are drawn
    uniformly from [0.8, 1.0] and indices 1 and 3 from [0.0, 0.2].

    Note: the previous annotation was ``torch.float``, which is a dtype,
    not a type; the function actually returns a tensor.

    :return: a FloatTensor of shape (4,).
    """
    real_data = torch.FloatTensor(
        [
            random.uniform(0.8, 1.0),
            random.uniform(0.0, 0.2),
            random.uniform(0.8, 1.0),
            random.uniform(0.0, 0.2),
        ])

    return real_data


def generate_random(size):
    """Return a tensor of the given size filled with uniform [0, 1) noise."""
    return torch.rand(size)


D = Discriminator()
G = Generator()

start = datetime.now()
for i in range(10000):
    # One GAN step: discriminator sees a real sample (target 1.0), then a
    # detached fake sample (target 0.0); the generator is then trained to
    # make the discriminator output 1.0 for its fake.
    D.train(generate_real_data(), torch.FloatTensor([1.0]))
    D.train(G.forward(torch.FloatTensor([0.5])).detach(), torch.FloatTensor([0.0]))
    G.train(D, torch.FloatTensor([0.5]), torch.FloatTensor([1.0]))
# Bug fix: subtracting only the `.second` fields is wrong across minute
# boundaries (and can go negative); use the full timedelta instead.
print(f'took {(datetime.now() - start).total_seconds():.1f} seconds')
Ejemplo n.º 2
0
def main():
    """Adversarial imitation-learning training loop for the dialog task.

    Builds actor/critic/discriminator networks, collects rollouts from a
    DialogEnvironment, trains the discriminator (GAIL-style) and the
    actor/critic, and periodically logs and checkpoints.

    Relies on module-level names (``args``, ``device``, ``Actor``,
    ``Critic``, ``Discriminator``, helper functions) defined elsewhere in
    the file.
    """
    env = DialogEnvironment()
    experiment_name = args.logdir.split('/')[1]  # model name

    torch.manual_seed(args.seed)

    actor = Actor(hidden_size=args.hidden_size, num_layers=args.num_layers,
                  device='cuda', input_size=args.input_size,
                  output_size=args.input_size)
    critic = Critic(hidden_size=args.hidden_size, num_layers=args.num_layers,
                    input_size=args.input_size, seq_len=args.seq_len)
    # Bug fix: num_layers was previously passed args.hidden_size.
    discrim = Discriminator(hidden_size=args.hidden_size,
                            num_layers=args.num_layers,
                            input_size=args.input_size, seq_len=args.seq_len)

    actor.to(device), critic.to(device), discrim.to(device)

    actor_optim = optim.Adam(actor.parameters(), lr=args.learning_rate)
    critic_optim = optim.Adam(critic.parameters(), lr=args.learning_rate,
                              weight_decay=args.l2_rate)
    discrim_optim = optim.Adam(discrim.parameters(), lr=args.learning_rate)

    writer = SummaryWriter(args.logdir)

    if args.load_model is not None:
        saved_ckpt_path = os.path.join(os.getcwd(), 'save_model', str(args.load_model))
        ckpt = torch.load(saved_ckpt_path)

        actor.load_state_dict(ckpt['actor'])
        critic.load_state_dict(ckpt['critic'])
        discrim.load_state_dict(ckpt['discrim'])

    episodes = 0
    train_discrim_flag = True

    for iter in range(args.max_iter_num):
        actor.eval(), critic.eval()
        memory = deque()

        steps = 0
        # Accumulate scores over ALL episodes of this iteration.  (Bug fix:
        # these lists were previously re-initialised inside the while loop,
        # so the reported averages only covered the last episode.)
        scores = []
        similarity_scores = []
        while steps < args.total_sample_size:
            state, expert_action, raw_state, raw_expert_action = env.reset()
            score = 0
            similarity_score = 0
            state = state[:args.seq_len, :].to(device)
            expert_action = expert_action[:args.seq_len, :].to(device)
            for _ in range(10000):

                steps += 1

                mu, std = actor(state.resize(1, args.seq_len, args.input_size))  # TODO: gotta be a better way to resize.
                action = get_action(mu.cpu(), std.cpu())[0]
                # Zero out the action past the expert's first all-zero row
                # (manual padding to match the expert sequence length).
                for i in range(5):
                    emb_sum = expert_action[i, :].sum().cpu().item()
                    if emb_sum == 0:
                        action[i:, :] = 0
                        break

                done = env.step(action)
                irl_reward = get_reward(discrim, state, action, args)
                mask = 0 if done else 1

                memory.append([state, torch.from_numpy(action).to(device),
                               irl_reward, mask, expert_action])
                score += irl_reward
                similarity_score += get_cosine_sim(expert=expert_action,
                                                   action=action.squeeze(),
                                                   seq_len=5)
                if done:
                    break

            episodes += 1
            scores.append(score)
            similarity_scores.append(similarity_score)

        score_avg = np.mean(scores)
        similarity_score_avg = np.mean(similarity_scores)
        print('{}:: {} episode score is {:.2f}'.format(iter, episodes, score_avg))
        print('{}:: {} episode similarity score is {:.2f}'.format(iter, episodes, similarity_score_avg))

        actor.train(), critic.train(), discrim.train()
        if train_discrim_flag:
            expert_acc, learner_acc = train_discrim(discrim, memory, discrim_optim, args)
            print("Expert: %.2f%% | Learner: %.2f%%" % (expert_acc * 100, learner_acc * 100))
            writer.add_scalar('log/expert_acc', float(expert_acc), iter)
            writer.add_scalar('log/learner_acc', float(learner_acc), iter)
            writer.add_scalar('log/avg_acc', float(learner_acc + expert_acc) / 2, iter)
            # Stop training the discriminator once both accuracies pass the
            # suspension thresholds (only checked when a threshold is set).
            if args.suspend_accu_exp is not None:
                if expert_acc > args.suspend_accu_exp and learner_acc > args.suspend_accu_gen:
                    train_discrim_flag = False

        train_actor_critic(actor, critic, memory, actor_optim, critic_optim, args)
        writer.add_scalar('log/score', float(score_avg), iter)
        writer.add_scalar('log/similarity_score', float(similarity_score_avg), iter)
        writer.add_text('log/raw_state', raw_state[0], iter)
        raw_action = get_raw_action(action)  # TODO
        writer.add_text('log/raw_action', raw_action, iter)
        writer.add_text('log/raw_expert_action', raw_expert_action, iter)

        # Bug fix: `if iter % 100:` triggered on every iteration EXCEPT
        # multiples of 100; checkpoint every 100th iteration instead.
        if iter % 100 == 0:
            score_avg = int(score_avg)
            result_str = str(iter) + '|' + raw_state[0] + '|' + raw_action + '|' + raw_expert_action + '\n'
            # Append this sample to the per-experiment transcript file.
            with open(experiment_name + '.txt', 'a') as file_object:
                file_object.write(result_str)

            model_path = os.path.join(os.getcwd(), 'save_model')
            if not os.path.isdir(model_path):
                os.makedirs(model_path)

            ckpt_path = os.path.join(model_path, experiment_name + '_ckpt_' + str(score_avg) + '.pth.tar')

            save_checkpoint({
                'actor': actor.state_dict(),
                'critic': critic.state_dict(),
                'discrim': discrim.state_dict(),
                'args': args,
                'score': score_avg,
            }, filename=ckpt_path)
Ejemplo n.º 3
0
class GanTrainer(Trainer):
    """GAN trainer for audio generation.

    Combines an adversarial loss with L2 losses in the time domain, in the
    frequency domain (spectrograms), and optionally in the embedding space
    of a pre-trained auto-encoder.  Extends the project ``Trainer`` base
    class, which is assumed to provide the device, loaders/iterators, loss
    dictionaries and ``check_improvement`` — confirm against the base class.
    """

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        """
        Builds the models, optimizers, schedulers and loss bookkeeping.
        :param train_loader: loader over training batches.
        :param test_loader: loader over test batches.
        :param valid_loader: loader over validation batches.
        :param general_args: architecture/configuration namespace shared by the models.
        :param trainer_args: trainer-specific configuration namespace.
        """
        super(GanTrainer, self).__init__(train_loader, test_loader,
                                         valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the auto-encoder (optional: only when a checkpoint exists)
        self.use_autoencoder = False
        if trainer_args.autoencoder_path and os.path.exists(
                trainer_args.autoencoder_path):
            self.use_autoencoder = True
            self.autoencoder = AutoEncoder(general_args=general_args).to(
                self.device)
            self.load_pretrained_autoencoder(trainer_args.autoencoder_path)
            self.autoencoder.eval()

        # Load the generator (optionally from a pre-trained checkpoint)
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states (resumes a previous run if a checkpoint exists)
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.adversarial_criterion = nn.BCEWithLogitsLoss()
        self.generator_time_criterion = nn.MSELoss()
        self.generator_frequency_criterion = nn.MSELoss()
        self.generator_autoencoder_criterion = nn.MSELoss()

        # Define labels
        self.real_label = 1
        self.generated_label = 0

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_freq = trainer_args.lambda_freq
        self.lambda_autoencoder = trainer_args.lambda_autoencoder

        # Spectrogram converter
        self.spectrogram = Spectrogram(normalized=True).to(self.device)

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Boolean if the generator receives the feedback from the discriminator
        self.use_adversarial = trainer_args.use_adversarial

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        checkpoint = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(checkpoint['generator_state_dict'])

    def load_pretrained_autoencoder(self, autoencoder_path):
        """
        Loads a pre-trained auto-encoder. Can be used to infer
        :param autoencoder_path: location of the pre-trained auto-encoder (string).
        :return: None
        """
        checkpoint = torch.load(autoencoder_path, map_location=self.device)
        self.autoencoder.load_state_dict(checkpoint['autoencoder_state_dict'])

    def train(self, epochs):
        """
        Trains the GAN for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                self.generator.train()
                self.discriminator.train()
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)
                batch_size = input_batch.shape[0]

                ############################
                # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
                ###########################
                # Train the discriminator with real data
                self.discriminator_optimizer.zero_grad()
                # NOTE(review): torch.full with an int fill value produces a
                # long tensor on recent PyTorch, while BCEWithLogitsLoss
                # expects float targets — may need dtype=torch.float here.
                label = torch.full((batch_size, ),
                                   self.real_label,
                                   device=self.device)
                output = self.discriminator(target_batch)

                # Compute and store the discriminator loss on real data
                loss_discriminator_real = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['real'].append(
                    loss_discriminator_real.item())
                loss_discriminator_real.backward()

                # Train the discriminator with fake data (detached so the
                # generator receives no gradient from this step)
                generated_batch = self.generator(input_batch)
                label.fill_(self.generated_label)
                output = self.discriminator(generated_batch.detach())

                # Compute and store the discriminator loss on fake data
                loss_discriminator_generated = self.adversarial_criterion(
                    output, torch.unsqueeze(label, dim=1))
                self.train_losses['discriminator_adversarial']['fake'].append(
                    loss_discriminator_generated.item())
                loss_discriminator_generated.backward()

                # Update the discriminator weights
                self.discriminator_optimizer.step()

                ############################
                # Update G network: maximize log(D(G(z)))
                ###########################
                self.generator_optimizer.zero_grad()

                # Get the spectrogram
                specgram_target_batch = self.spectrogram(target_batch)
                specgram_fake_batch = self.spectrogram(generated_batch)

                # Fake labels are real for the generator cost
                label.fill_(self.real_label)
                output = self.discriminator(generated_batch)

                # Compute the generator loss on fake data
                # Get the adversarial loss (zero when adversarial feedback is
                # disabled, so the combined loss formula stays unchanged)
                loss_generator_adversarial = torch.zeros(size=[1],
                                                         device=self.device)
                if self.use_adversarial:
                    loss_generator_adversarial = self.adversarial_criterion(
                        output, torch.unsqueeze(label, dim=1))
                self.train_losses['generator_adversarial'].append(
                    loss_generator_adversarial.item())

                # Get the L2 loss in time domain
                loss_generator_time = self.generator_time_criterion(
                    generated_batch, target_batch)
                self.train_losses['time_l2'].append(loss_generator_time.item())

                # Get the L2 loss in frequency domain
                loss_generator_frequency = self.generator_frequency_criterion(
                    specgram_fake_batch, specgram_target_batch)
                self.train_losses['freq_l2'].append(
                    loss_generator_frequency.item())

                # Get the L2 loss in embedding space
                loss_generator_autoencoder = torch.zeros(size=[1],
                                                         device=self.device,
                                                         requires_grad=True)
                if self.use_autoencoder:
                    # Get the embeddings
                    _, embedding_target_batch = self.autoencoder(target_batch)
                    _, embedding_generated_batch = self.autoencoder(
                        generated_batch)
                    loss_generator_autoencoder = self.generator_autoencoder_criterion(
                        embedding_generated_batch, embedding_target_batch)
                    self.train_losses['autoencoder_l2'].append(
                        loss_generator_autoencoder.item())

                # Combine the different losses (weighted sum; the time-domain
                # term is unweighted)
                loss_generator = self.lambda_adv * loss_generator_adversarial + loss_generator_time + \
                                 self.lambda_freq * loss_generator_frequency + \
                                 self.lambda_autoencoder * loss_generator_autoencoder

                # Back-propagate and update the generator weights
                loss_generator.backward()
                self.generator_optimizer.step()

                # Print message every 10 batches
                if not (i % 10):
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Frequency: {} \n' \
                              '\t\t Autoencoder {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Real {} \n' \
                              '\t\t Fake {} \n'.format(i,
                                                       loss_generator_time.item(),
                                                       loss_generator_frequency.item(),
                                                       loss_generator_autoencoder.item(),
                                                       loss_generator_adversarial.item(),
                                                       loss_discriminator_real.item(),
                                                       loss_discriminator_generated.item())
                    print(message)

            # Evaluate the model
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()
            # if self.need_saving:
            #     self.save()

            # Increment epoch counter
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Runs one pass over the validation set, stores and displays the
        time/frequency L2 losses, then checks for improvement.
        :return: None
        """
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': [], 'freq_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            # Get the spectrogram
            specgram_target_batch = self.spectrogram(target_batch)
            specgram_generated_batch = self.spectrogram(generated_batch)

            loss_generator_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_generator_time.item())
            loss_generator_frequency = self.generator_frequency_criterion(
                specgram_generated_batch, specgram_target_batch)
            batch_losses['freq_l2'].append(loss_generator_frequency.item())

        # Store the validation losses
        self.valid_losses['time_l2'].append(np.mean(batch_losses['time_l2']))
        self.valid_losses['freq_l2'].append(np.mean(batch_losses['freq_l2']))

        # Display validation losses
        # NOTE(review): the inner np.mean already reduces the list to a
        # scalar; the outer np.mean is redundant (harmless).
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n' \
                  '\t Frequency: {} \n'.format(self.epoch,
                                               np.mean(np.mean(batch_losses['time_l2'])),
                                               np.mean(np.mean(batch_losses['freq_l2'])))
        print(message)

        # Check if the loss is decreasing
        self.check_improvement()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        torch.save(
            {
                'epoch':
                self.epoch,
                'generator_state_dict':
                self.generator.state_dict(),
                'discriminator_state_dict':
                self.discriminator.state_dict(),
                'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
                'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
                'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
                'train_losses':
                self.train_losses,
                'test_losses':
                self.test_losses,
                'valid_losses':
                self.valid_losses
            }, self.savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        checkpoint = torch.load(self.loadpath, map_location=self.device)
        self.epoch = checkpoint['epoch']
        self.generator.load_state_dict(checkpoint['generator_state_dict'])
        self.discriminator.load_state_dict(
            checkpoint['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            checkpoint['generator_optimizer_state_dict'])
        self.discriminator_optimizer.load_state_dict(
            checkpoint['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            checkpoint['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            checkpoint['discriminator_scheduler_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.test_losses = checkpoint['test_losses']
        self.valid_losses = checkpoint['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        with torch.no_grad():
            snrs = []
            lsds = []
            generator = self.generator.eval()
            for k in range(n_batches):
                # Transfer to GPU
                local_batch = next(self.test_loader_iter)
                # Transfer to GPU
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Generates a batch
                generated_batch = generator(input_batch)

                # Get the metrics
                snrs.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsds.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snrs = torch.cat(snrs).cpu().numpy()
            lsds = torch.cat(lsds).cpu().numpy()

            # Some signals corresponding to silence will be all zeroes and cause troubles due to the logarithm
            snrs[np.isinf(snrs)] = np.nan
            lsds[np.isinf(lsds)] = np.nan
        return np.nanmean(snrs), np.nanstd(snrs), np.nanmean(lsds), np.nanstd(
            lsds)
Ejemplo n.º 4
0
class Model(object):
    """GAN wrapper holding a generator, a discriminator and an optional
    pre-trained regressor used as an MSE feature loss.

    Uses the deprecated ``torch.autograd.Variable`` wrapper throughout;
    on modern PyTorch plain tensors behave the same way.
    """

    def __init__(self, opt):
        """
        Builds the networks and losses from the ``opt`` configuration.
        :param opt: options namespace (gpu_id, mse_weight, paths, channel
                    settings, ...).
        """
        super(Model, self).__init__()

        # Generator
        self.gen = Generator(opt).cuda(opt.gpu_id)

        self.gen_params = self.gen.parameters()

        # Report the generator architecture and its parameter count
        num_params = 0
        for p in self.gen.parameters():
            num_params += p.numel()
        print(self.gen)
        print(num_params)

        # Discriminator
        self.dis = Discriminator(opt).cuda(opt.gpu_id)

        self.dis_params = self.dis.parameters()

        # Report the discriminator architecture and its parameter count
        num_params = 0
        for p in self.dis.parameters():
            num_params += p.numel()
        print(self.dis)
        print(num_params)

        # Regressor (only loaded when the MSE loss is enabled)
        # NOTE(review): torch.load of a full pickled model is fragile and
        # unsafe on untrusted files — consider a state_dict checkpoint.
        if opt.mse_weight:
            self.reg = torch.load('data/utils/classifier.pth').cuda(
                opt.gpu_id).eval()
        else:
            self.reg = None

        # Losses (criterion_mse is actually an L1 loss scaled by mse_weight)
        self.criterion_gan = GANLoss(opt, self.dis)
        self.criterion_mse = lambda x, y: l1_loss(x, y) * opt.mse_weight

        self.loss_mse = Variable(torch.zeros(1).cuda())
        self.loss_adv = Variable(torch.zeros(1).cuda())
        self.loss = Variable(torch.zeros(1).cuda())

        self.path = opt.experiments_dir + opt.experiment_name + '/checkpoints/'
        self.gpu_id = opt.gpu_id
        # Number of noise channels = generator input channels minus the
        # number of conditioning channels listed in opt.input_idx
        self.noise_channels = opt.in_channels - len(opt.input_idx.split(','))

    def forward(self, inputs):
        """
        Runs the generator on a (input, input_orig, target) triple and
        stores input/target/fake on the instance for the backward passes.
        :param inputs: tuple of (input, input_orig, target) tensors.
        :return: None (result stored in self.fake).
        """
        input, input_orig, target = inputs

        self.input = Variable(input.cuda(self.gpu_id))
        self.input_orig = Variable(input_orig.cuda(self.gpu_id))
        self.target = Variable(target.cuda(self.gpu_id))

        # Sample fresh noise and concatenate it with the conditioning input
        noise = Variable(
            torch.randn(self.input.size(0),
                        self.noise_channels).cuda(self.gpu_id))

        self.fake = self.gen(torch.cat([self.input, noise], 1))

    def backward_G(self):
        """Back-propagates the generator loss (regressor MSE + GAN)."""

        # Regressor loss
        if self.reg is not None:

            fake_input = self.reg(self.fake)

            self.loss_mse = self.criterion_mse(fake_input, self.input_orig)

        # GAN loss
        loss_adv, _ = self.criterion_gan(self.fake)

        loss_G = self.loss_mse + loss_adv
        loss_G.backward()

    def backward_D(self):
        """Back-propagates the discriminator's adversarial loss."""

        loss_adv, self.loss_adv = self.criterion_gan(self.target, self.fake)

        loss_D = loss_adv
        loss_D.backward()

    def train(self):
        """Puts both networks in training mode."""

        self.gen.train()
        self.dis.train()

    def eval(self):
        """Puts both networks in evaluation mode."""

        self.gen.eval()
        self.dis.eval()

    def save_checkpoint(self, epoch):
        """
        Saves both state dicts to ``<path>/<epoch>.pkl``.
        :param epoch: epoch number used in the filename (int).
        """
        torch.save(
            {
                'epoch': epoch,
                'gen_state_dict': self.gen.state_dict(),
                'dis_state_dict': self.dis.state_dict()
            }, self.path + '%d.pkl' % epoch)

    def load_checkpoint(self, path, pretrained=True):
        """
        Loads generator and discriminator state dicts from a checkpoint.
        :param path: checkpoint file path (string).
        :param pretrained: NOTE(review): currently unused — confirm intent.
        """
        weights = torch.load(path)

        self.gen.load_state_dict(weights['gen_state_dict'])
        self.dis.load_state_dict(weights['dis_state_dict'])
Ejemplo n.º 5
0
class WGanTrainer(Trainer):
    """Wasserstein GAN trainer (WGAN / WGAN-GP selectable via trainer_args)."""

    def __init__(self, train_loader, test_loader, valid_loader, general_args,
                 trainer_args):
        """
        Builds the models, optimizers, schedulers and loss bookkeeping.
        :param train_loader: loader over training batches.
        :param test_loader: loader over test batches.
        :param valid_loader: loader over validation batches.
        :param general_args: architecture/configuration namespace shared by the models.
        :param trainer_args: trainer-specific configuration namespace.
        """
        super(WGanTrainer, self).__init__(train_loader, test_loader,
                                          valid_loader, general_args)
        # Paths
        self.loadpath = trainer_args.loadpath
        self.savepath = trainer_args.savepath

        # Load the generator (optionally from a pre-trained checkpoint)
        self.generator = Generator(general_args=general_args).to(self.device)
        if trainer_args.generator_path and os.path.exists(
                trainer_args.generator_path):
            self.load_pretrained_generator(trainer_args.generator_path)

        self.discriminator = Discriminator(general_args=general_args).to(
            self.device)

        # Optimizers and schedulers
        self.generator_optimizer = torch.optim.Adam(
            params=self.generator.parameters(), lr=trainer_args.generator_lr)
        self.discriminator_optimizer = torch.optim.Adam(
            params=self.discriminator.parameters(),
            lr=trainer_args.discriminator_lr)
        self.generator_scheduler = lr_scheduler.StepLR(
            optimizer=self.generator_optimizer,
            step_size=trainer_args.generator_scheduler_step,
            gamma=trainer_args.generator_scheduler_gamma)
        self.discriminator_scheduler = lr_scheduler.StepLR(
            optimizer=self.discriminator_optimizer,
            step_size=trainer_args.discriminator_scheduler_step,
            gamma=trainer_args.discriminator_scheduler_gamma)

        # Load saved states (resumes a previous run if a checkpoint exists)
        if os.path.exists(self.loadpath):
            self.load()

        # Loss function and stored losses
        self.generator_time_criterion = nn.MSELoss()

        # Loss scaling factors
        self.lambda_adv = trainer_args.lambda_adversarial
        self.lambda_time = trainer_args.lambda_time

        # Boolean indicating if the model needs to be saved
        self.need_saving = True

        # Overrides losses from parent class
        self.train_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.test_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }
        self.valid_losses = {
            'generator': {
                'time_l2': [],
                'adversarial': []
            },
            'discriminator': {
                'penalty': [],
                'adversarial': []
            }
        }

        # Select either wgan or wgan-gp method: gradient penalty (wgan-gp)
        # vs weight clipping (vanilla wgan)
        self.use_penalty = trainer_args.use_penalty
        self.gamma = trainer_args.gamma_wgan_gp
        self.clipping_limit = trainer_args.clipping_limit
        self.n_critic = trainer_args.n_critic
        self.coupling_epoch = trainer_args.coupling_epoch

    def load_pretrained_generator(self, generator_path):
        """
        Loads a pre-trained generator. Can be used to stabilize the training.
        :param generator_path: location of the pre-trained generator (string).
        :return: None
        """
        state = torch.load(generator_path, map_location=self.device)
        self.generator.load_state_dict(state['generator_state_dict'])

    def compute_gradient_penalty(self, input_batch, generated_batch):
        """
        Compute the gradient penalty as described in the original paper
        (https://papers.nips.cc/paper/7159-improved-training-of-wasserstein-gans.pdf).
        :param input_batch: batch of input data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: penalty as a scalar (torch tensor).
        """
        batch_size = input_batch.size(0)
        # One interpolation coefficient per sample
        # NOTE(review): the (batch, 1, 1) shape assumes 3-D inputs
        # (batch, channel, time) — confirm against the data loaders.
        epsilon = torch.rand(batch_size, 1, 1)
        epsilon = epsilon.expand_as(input_batch).to(self.device)

        # Interpolate between real (input) and generated data
        interpolation = epsilon * input_batch.data + (
            1 - epsilon) * generated_batch.data
        interpolation = interpolation.requires_grad_(True).to(self.device)

        # Computes the discriminator's prediction for the interpolated input
        interpolation_logits = self.discriminator(interpolation)

        # Computes a vector of outputs to make it works with 2 output classes if needed
        grad_outputs = torch.ones_like(interpolation_logits).to(
            self.device).requires_grad_(True)

        # Get the gradients and retain the graph so that the penalty can be back-propagated
        gradients = autograd.grad(outputs=interpolation_logits,
                                  inputs=interpolation,
                                  grad_outputs=grad_outputs,
                                  create_graph=True,
                                  retain_graph=True,
                                  only_inputs=True)[0]
        gradients = gradients.view(batch_size, -1)

        # Computes the norm of the gradients (per-sample L2 norm)
        gradients_norm = torch.sqrt(torch.sum(gradients**2, dim=1))
        # Penalize deviation of the gradient norm from 1 (1-Lipschitz target)
        return ((gradients_norm - 1)**2).mean()

    def train_discriminator_step(self, input_batch, target_batch):
        """
        Trains the discriminator for a single step based on the wasserstein gan-gp framework.
        :param input_batch: batch of input data (torch tensor).
        :param target_batch: batch of target data (torch tensor).
        :return: a batch of generated data (torch tensor).
        """
        # Activate gradient tracking for the discriminator
        self.change_discriminator_grad_requirement(requires_grad=True)

        # Set the discriminator's gradients to zero
        self.discriminator_optimizer.zero_grad()

        # Generate a batch and compute the penalty
        generated_batch = self.generator(input_batch)

        # Compute the loss: Wasserstein critic estimate
        # E[D(fake)] - E[D(real)] (fake is detached so the generator is
        # not updated here)
        loss_d = self.discriminator(generated_batch.detach()).mean(
        ) - self.discriminator(target_batch).mean()
        self.train_losses['discriminator']['adversarial'].append(loss_d.item())
        if self.use_penalty:
            # WGAN-GP: add the gradient penalty scaled by gamma
            penalty = self.compute_gradient_penalty(input_batch,
                                                    generated_batch.detach())
            self.train_losses['discriminator']['penalty'].append(
                penalty.item())
            loss_d = loss_d + self.gamma * penalty

        # Update the discriminator's weights
        loss_d.backward()
        self.discriminator_optimizer.step()

        # Apply the weight constraint if needed (vanilla WGAN weight
        # clipping, used only when the gradient penalty is disabled)
        if not self.use_penalty:
            for p in self.discriminator.parameters():
                p.data.clamp_(min=-self.clipping_limit,
                              max=self.clipping_limit)

        # Return the generated batch to avoid redundant computation
        return generated_batch

    def train_generator_step(self, target_batch, generated_batch):
        """
        Performs one generator update within the WGAN-GP framework.
        :param target_batch: batch of target data (torch tensor).
        :param generated_batch: batch of generated data (torch tensor).
        :return: None
        """
        # Freeze the critic so its gradients are not computed needlessly
        self.change_discriminator_grad_requirement(requires_grad=False)

        # Reset the generator's accumulated gradients
        self.generator_optimizer.zero_grad()

        # Adversarial term: push the critic's score on the fakes upward
        adversarial_loss = -self.discriminator(generated_batch).mean()
        # Time-domain reconstruction term
        time_loss = self.generator_time_criterion(generated_batch,
                                                  target_batch)

        # Before the coupling epoch, train on the time loss alone
        total_loss = time_loss
        if self.epoch >= self.coupling_epoch:
            total_loss = total_loss + self.lambda_adv * adversarial_loss

        # Back-propagate and update the generator's weights
        total_loss.backward()
        self.generator_optimizer.step()

        # Record both loss components
        self.train_losses['generator']['time_l2'].append(time_loss.item())
        self.train_losses['generator']['adversarial'].append(
            adversarial_loss.item())

    def change_discriminator_grad_requirement(self, requires_grad):
        """
        Toggles gradient tracking on every discriminator parameter. Not strictly required, since the
        discriminator's optimizer is never stepped after a generator update, but it avoids wasted computation.
        :param requires_grad: whether discriminator parameters should track gradients (boolean).
        :return: None
        """
        for parameter in self.discriminator.parameters():
            parameter.requires_grad_(requires_grad)

    def train(self, epochs):
        """
        Trains the WGAN-GP for a given number of pseudo-epochs.
        :param epochs: Number of time to iterate over a part of the dataset (int).
        :return: None
        """
        self.generator.train()
        self.discriminator.train()
        for epoch in range(epochs):
            for i in range(self.train_batches_per_epoch):
                # Transfer to GPU
                local_batch = next(self.train_loader_iter)
                input_batch, target_batch = local_batch[0].to(
                    self.device), local_batch[1].to(self.device)

                # Train the discriminator
                generated_batch = self.train_discriminator_step(
                    input_batch, target_batch)

                # Train the generator every n_critic
                if not (i % self.n_critic):
                    self.train_generator_step(target_batch, generated_batch)

                # Print a status message every 10 batches
                if not (i % 10):
                    # Fix: the penalty history stays empty when use_penalty is
                    # False (train_discriminator_step never appends to it), so
                    # indexing [-1] unconditionally raised an IndexError on the
                    # very first report in that configuration.
                    penalty_history = self.train_losses['discriminator']['penalty']
                    last_penalty = penalty_history[-1] if penalty_history else float('nan')
                    message = 'Batch {}: \n' \
                              '\t Generator: \n' \
                              '\t\t Time: {} \n' \
                              '\t\t Adversarial: {} \n' \
                              '\t Discriminator: \n' \
                              '\t\t Penalty: {}\n' \
                              '\t\t Adversarial: {} \n'.format(i,
                                                               self.train_losses['generator']['time_l2'][-1],
                                                               self.train_losses['generator']['adversarial'][-1],
                                                               last_penalty,
                                                               self.train_losses['discriminator']['adversarial'][-1])
                    print(message)

            # Evaluate the model on the validation set
            with torch.no_grad():
                self.eval()

            # Save the trainer state
            self.save()

            # Increment epoch counter and anneal the learning rates
            self.epoch += 1
            self.generator_scheduler.step()
            self.discriminator_scheduler.step()

    def eval(self):
        """
        Evaluates the generator's time-domain L2 loss on the validation set,
        stores the epoch average in valid_losses and prints it.
        :return: None
        """
        # Set the models in evaluation mode
        self.generator.eval()
        self.discriminator.eval()
        batch_losses = {'time_l2': []}
        for i in range(self.valid_batches_per_epoch):
            # Transfer to GPU
            local_batch = next(self.valid_loader_iter)
            input_batch, target_batch = local_batch[0].to(
                self.device), local_batch[1].to(self.device)

            generated_batch = self.generator(input_batch)

            loss_g_time = self.generator_time_criterion(
                generated_batch, target_batch)
            batch_losses['time_l2'].append(loss_g_time.item())

        # Store the validation losses (single mean; the original redundantly
        # called np.mean twice on the same list)
        mean_time_loss = np.mean(batch_losses['time_l2'])
        self.valid_losses['generator']['time_l2'].append(mean_time_loss)

        # Display validation losses
        message = 'Epoch {}: \n' \
                  '\t Time: {} \n'.format(self.epoch, mean_time_loss)
        print(message)

        # Set the models back in train mode.
        # Fix: the discriminator was previously left in eval mode here
        # (`self.discriminator.eval()`), contradicting the comment and
        # disabling its dropout/batch-norm training behavior for every
        # subsequent training step of the epoch loop.
        self.generator.train()
        self.discriminator.train()

    def save(self):
        """
        Saves the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        # Rotate the checkpoint file name every 5 epochs
        name_parts = self.savepath.split('.')
        savepath = '{}_{}.{}'.format(name_parts[0], self.epoch // 5,
                                     name_parts[1])
        state = {
            'epoch': self.epoch,
            'generator_state_dict': self.generator.state_dict(),
            'discriminator_state_dict': self.discriminator.state_dict(),
            'generator_optimizer_state_dict':
                self.generator_optimizer.state_dict(),
            'discriminator_optimizer_state_dict':
                self.discriminator_optimizer.state_dict(),
            'generator_scheduler_state_dict':
                self.generator_scheduler.state_dict(),
            'discriminator_scheduler_state_dict':
                self.discriminator_scheduler.state_dict(),
            'train_losses': self.train_losses,
            'test_losses': self.test_losses,
            'valid_losses': self.valid_losses,
        }
        torch.save(state, savepath)

    def load(self):
        """
        Loads the model(s), optimizer(s), scheduler(s) and losses
        :return: None
        """
        state = torch.load(self.loadpath, map_location=self.device)
        self.epoch = state['epoch']
        self.generator.load_state_dict(state['generator_state_dict'])
        # NOTE: the discriminator's weights and optimizer are deliberately
        # NOT restored (kept commented out, matching the original behavior):
        # self.discriminator.load_state_dict(state['discriminator_state_dict'])
        self.generator_optimizer.load_state_dict(
            state['generator_optimizer_state_dict'])
        # self.discriminator_optimizer.load_state_dict(state['discriminator_optimizer_state_dict'])
        self.generator_scheduler.load_state_dict(
            state['generator_scheduler_state_dict'])
        self.discriminator_scheduler.load_state_dict(
            state['discriminator_scheduler_state_dict'])
        self.train_losses = state['train_losses']
        self.test_losses = state['test_losses']
        self.valid_losses = state['valid_losses']

    def evaluate_metrics(self, n_batches):
        """
        Evaluates the quality of the reconstruction with the SNR and LSD metrics on a specified number of batches
        :param: n_batches: number of batches to process
        :return: mean and std for each metric
        """
        snr_scores = []
        lsd_scores = []
        generator = self.generator.eval()
        with torch.no_grad():
            for _ in range(n_batches):
                # Fetch a batch and move it to the device
                local_batch = next(self.test_loader_iter)
                input_batch = local_batch[0].to(self.device)
                target_batch = local_batch[1].to(self.device)

                # Reconstruct the signals
                generated_batch = generator(input_batch)

                # Accumulate both metrics for this batch
                snr_scores.append(
                    snr(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))
                lsd_scores.append(
                    lsd(x=generated_batch.squeeze(),
                        x_ref=target_batch.squeeze()))

            snr_scores = torch.cat(snr_scores).cpu().numpy()
            lsd_scores = torch.cat(lsd_scores).cpu().numpy()

            # Silent (all-zero) signals yield infinite values through the
            # logarithm; mask them as NaN so they are ignored by nanmean/nanstd
            snr_scores[np.isinf(snr_scores)] = np.nan
            lsd_scores[np.isinf(lsd_scores)] = np.nan
        return (np.nanmean(snr_scores), np.nanstd(snr_scores),
                np.nanmean(lsd_scores), np.nanstd(lsd_scores))
Ejemplo n.º 6
0
class Seq2SeqCycleGAN:
    """
    CycleGAN over sequences of one-hot word vectors for unsupervised text
    style transfer between two domains A and B (here Shakespearean and
    modern English, per the printed labels in forward()).

    Two seq2seq generators (G_AtoB, G_BtoA) share a soft embedding layer;
    in 'train' mode two discriminators (D_A, D_B), adversarial/cycle/identity
    losses and two Adam optimizers are also built.
    """

    def __init__(self,
                 model_config,
                 train_config,
                 vocab,
                 max_len,
                 mode='train'):
        """
        Builds the networks and, in train mode, the losses and optimizers.

        :param model_config: dict of architecture hyper-parameters
            (at least 'embedding_size').
        :param train_config: dict of training hyper-parameters
            ('base_lr', 'batch_size', 'continue_train', 'which_epoch', ...).
        :param vocab: vocabulary object exposing num_words.
        :param max_len: maximum sequence length passed to the generators.
        :param mode: 'train' builds discriminators and optimizers; any other
            value loads pretrained weights and switches everything to eval.
        """
        self.mode = mode

        self.model_config = model_config
        self.train_config = train_config

        self.vocab = vocab
        self.vocab_size = self.vocab.num_words
        self.max_len = max_len

        # Soft embedding: a linear projection of one-hot (or softmax) word
        # vectors followed by a sigmoid, instead of an index-based lookup:
        # self.embedding_layer = nn.Embedding(vocab_size, model_config['embedding_size'], padding_idx=PAD_token)
        self.embedding_layer = nn.Sequential(
            nn.Linear(self.vocab_size, self.model_config['embedding_size']),
            nn.Sigmoid())

        self.G_AtoB = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()
        self.G_BtoA = Generator(self.embedding_layer,
                                self.model_config,
                                self.train_config,
                                self.vocab_size,
                                self.max_len,
                                mode=self.mode).cuda()

        if self.mode == 'train':
            # Discriminators are only needed during training
            self.D_B = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()
            self.D_A = Discriminator(self.embedding_layer, self.model_config,
                                     self.train_config).cuda()

            # Optionally resume all five modules from epoch-tagged checkpoints
            if self.train_config['continue_train']:
                self.embedding_layer.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_embedding_layer.pth'))
                self.G_AtoB.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_G_AtoB.pth'))
                self.G_BtoA.load_state_dict(
                    torch.load(self.train_config['which_epoch'] +
                               '_G_BtoA.pth'))
                self.D_B.load_state_dict(
                    torch.load(self.train_config['which_epoch'] + '_D_B.pth'))
                self.D_A.load_state_dict(
                    torch.load(self.train_config['which_epoch'] + '_D_A.pth'))

            self.embedding_layer.train()
            self.G_AtoB.train()
            self.G_BtoA.train()
            self.D_B.train()
            self.D_A.train()

            # BCE for real/fake discrimination, CE for cycle/identity
            # reconstruction against token labels
            self.criterionBCE = nn.BCELoss().cuda()
            self.criterionCE = nn.CrossEntropyLoss().cuda()

            # NOTE(review): the shared embedding layer is part of BOTH
            # optimizers, so it receives updates from the generator and the
            # discriminator passes — confirm this is intentional.
            self.optimizer_G = torch.optim.Adam(itertools.chain(
                self.embedding_layer.parameters(), self.G_AtoB.parameters(),
                self.G_BtoA.parameters()),
                                                lr=train_config['base_lr'],
                                                betas=(0.9, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(
                self.embedding_layer.parameters(), self.D_A.parameters(),
                self.D_B.parameters()),
                                                lr=train_config['base_lr'],
                                                betas=(0.9, 0.999))

            # Fixed label tensors, sized to the training batch size
            self.real_label = torch.ones(
                (train_config['batch_size'], 1)).cuda()
            self.fake_label = torch.zeros(
                (train_config['batch_size'], 1)).cuda()
        else:
            # Inference mode: load pretrained embedding and generators only
            self.embedding_layer.load_state_dict(
                torch.load(self.train_config['which_epoch'] +
                           '_embedding_layer.pth'))
            self.G_AtoB.load_state_dict(
                torch.load(self.train_config['which_epoch'] + '_G_AtoB.pth'))
            self.G_BtoA.load_state_dict(
                torch.load(self.train_config['which_epoch'] + '_G_BtoA.pth'))

            self.embedding_layer.eval()
            self.G_AtoB.eval()
            self.G_BtoA.eval()

    def backward_D_basic(self, netD, real, real_addn_feats, fake,
                         fake_addn_feats):
        """
        Computes and back-propagates the standard GAN discriminator loss
        0.5 * (BCE(D(real), 1) + BCE(D(fake.detach()), 0)), then clips the
        gradients of the shared embedding and of netD.

        :param netD: discriminator network to update.
        :param real: batch of real sequences.
        :param real_addn_feats: additional features for the real batch.
        :param fake: batch of generated sequences (detached before scoring).
        :param fake_addn_feats: additional features for the fake batch.
        :return: the scalar discriminator loss (already back-propagated).
        """
        netD.hidden = netD.init_hidden()
        pred_real = netD(real, real_addn_feats)
        loss_D_real = self.criterionBCE(pred_real, self.real_label)

        netD.hidden = netD.init_hidden()
        pred_fake = netD(fake.detach(), fake_addn_feats)
        loss_D_fake = self.criterionBCE(pred_fake, self.fake_label)

        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()

        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(netD)

        return loss_D

    def backward_D_A(self):
        # Update D_A on real A vs. fake A produced by G_BtoA.
        # NOTE(review): the * 10 is applied AFTER backward() has run inside
        # backward_D_basic, so it scales only the stored/logged value, not
        # the gradients — confirm this asymmetry with backward_G is intended.
        self.loss_D_A = self.backward_D_basic(
            self.D_A, self.real_A, self.real_A_addn_feats, self.fake_A,
            self.fake_A_addn_feats) * 10

    def backward_D_B(self):
        # Update D_B on real B vs. fake B produced by G_AtoB (see the note
        # in backward_D_A about the post-backward * 10 scaling).
        self.loss_D_B = self.backward_D_basic(
            self.D_B, self.real_B, self.real_B_addn_feats, self.fake_B,
            self.fake_B_addn_feats) * 10

    def backward_G(self):
        """
        Computes and back-propagates the full generator objective:
        adversarial (x10), cycle-consistency and identity losses for both
        directions, then clips gradients of the embedding and generators.
        Requires a prior call to forward() to populate real_*/fake_*/rec_*.
        :return: None
        """
        # Adversarial loss for A->B: D_B should score fake B as real
        self.D_B.hidden = self.D_B.init_hidden()
        self.fake_B_addn_feats = get_addn_feats(self.fake_B, self.vocab).cuda()
        self.loss_G_AtoB = self.criterionBCE(
            self.D_B(self.fake_B, self.fake_B_addn_feats),
            self.real_label) * 10

        # Adversarial loss for B->A: D_A should score fake A as real
        self.D_A.hidden = self.D_A.init_hidden()
        self.fake_A_addn_feats = get_addn_feats(self.fake_A, self.vocab).cuda()
        self.loss_G_BtoA = self.criterionBCE(
            self.D_A(self.fake_A, self.fake_A_addn_feats),
            self.real_label) * 10

        # Cycle loss A -> B -> A; generated sequences may differ in length
        # from the labels, so pad/trim first when sizes disagree
        if self.rec_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.rec_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.rec_A, self.real_A_label)
        self.loss_cycle_A = self.criterionCE(self.rec_A,
                                             self.real_A_label)  #* lambda_A

        # Cycle loss B -> A -> B
        if self.rec_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.rec_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.rec_B, self.real_B_label)
        self.loss_cycle_B = self.criterionCE(self.rec_B,
                                             self.real_B_label)  #* lambda_B

        # Identity loss: G_AtoB applied to real B should reproduce B
        self.idt_B = self.G_AtoB(self.real_B)
        if self.idt_B.size(0) != self.real_B_label.size(0):
            self.real_B, self.idt_B, self.real_B_label = self.update_label_sizes(
                self.real_B, self.idt_B, self.real_B_label)
        self.loss_idt_B = self.criterionCE(
            self.idt_B, self.real_B_label)  #* lambda_B * lambda_idt

        # Identity loss: G_BtoA applied to real A should reproduce A
        self.idt_A = self.G_BtoA(self.real_A)
        if self.idt_A.size(0) != self.real_A_label.size(0):
            self.real_A, self.idt_A, self.real_A_label = self.update_label_sizes(
                self.real_A, self.idt_A, self.real_A_label)
        self.loss_idt_A = self.criterionCE(
            self.idt_A, self.real_A_label)  #* lambda_A * lambda_idt

        # Combined objective, back-propagation and gradient clipping
        self.loss_G = self.loss_G_AtoB + self.loss_G_BtoA + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

        self.clip_gradient(self.embedding_layer)
        self.clip_gradient(self.G_AtoB)
        self.clip_gradient(self.G_BtoA)

    def forward(self, real_A, real_A_addn_feats, real_B, real_B_addn_feats):
        """
        Runs both translation directions and caches all intermediate tensors
        on self for the backward_* methods. In non-train mode it instead
        decodes and prints the translations of both inputs.

        :param real_A: batch of domain-A sequences (one-hot-like vectors).
        :param real_A_addn_feats: additional features for the A batch.
        :param real_B: batch of domain-B sequences.
        :param real_B_addn_feats: additional features for the B batch.
        :return: None
        """
        self.real_A = real_A
        self.real_A_addn_feats = real_A_addn_feats
        # Token labels recovered via argmax over the vocabulary axis
        self.real_A_label = self.real_A.max(dim=1)[1]

        self.real_B = real_B
        self.real_B_addn_feats = real_B_addn_feats
        self.real_B_label = self.real_B.max(dim=1)[1]

        # Soft (softmax) translations keep the pipeline differentiable
        self.fake_B = F.softmax(self.G_AtoB.forward(self.real_A), dim=1)
        self.fake_A = F.softmax(self.G_BtoA.forward(self.real_B), dim=1)

        if self.mode == 'train':
            # Cycle reconstructions for the cycle-consistency losses
            self.rec_A = self.G_BtoA.forward(self.fake_B)
            self.rec_B = self.G_AtoB.forward(self.fake_A)

        else:
            # Inference: decode argmax indices back to sentences and print
            real_A_list = self.real_A.max(dim=1)[1].tolist()
            real_B_list = self.real_B.max(dim=1)[1].tolist()

            fake_B_list = self.fake_B.max(dim=1)[1].tolist()
            fake_A_list = self.fake_A.max(dim=1)[1].tolist()

            print('Input (Shakespeare):', idx_to_sent(real_A_list, self.vocab))
            print('Output (Modern):', idx_to_sent(fake_B_list, self.vocab))
            print('\n')
            print('Input (Modern):', idx_to_sent(real_B_list, self.vocab))
            print('Output (Shakespeare):', idx_to_sent(fake_A_list,
                                                       self.vocab))
            print('\n')

    def optimize_parameters(self):
        """
        One full optimization step: update the generators with frozen
        discriminators, then update both discriminators. forward() must have
        been called first.
        :return: None
        """
        self.set_requires_grad([self.D_A, self.D_B], False)
        self.optimizer_G.zero_grad()
        self.backward_G()
        self.optimizer_G.step()

        self.set_requires_grad([self.D_A, self.D_B], True)
        self.optimizer_D.zero_grad()
        self.backward_D_B()
        self.backward_D_A()
        self.optimizer_D.step()

    def update_label_sizes(self, real, rec, real_label):
        """
        Reconciles a length mismatch between a reconstructed sequence and its
        label tensor: pads the labels with zeros when rec is longer, or pads
        rec with one-hot vectors for index 0 when it is shorter.

        :param real: the real sequence batch (returned unchanged).
        :param rec: the reconstructed/generated sequence batch.
        :param real_label: the argmax token labels of the real batch.
        :return: (real, rec, real_label) with matching first dimensions.
        """
        if rec.size(0) > real.size(0):
            # Extend labels with zeros (presumably the PAD token — verify)
            real_label = torch.cat(
                (real_label, torch.zeros((rec.size(0) - real.size(0))).type(
                    torch.LongTensor).cuda()), 0)
        elif rec.size(0) < real.size(0):
            # Extend rec with one-hot vectors selecting index 0
            diff = real.size(0) - rec.size(0)
            to_concat = torch.zeros((diff, self.vocab_size)).cuda()
            to_concat[:, 0] = 1
            rec = torch.cat((rec, to_concat), 0)

        return real, rec, real_label

    def indices_to_one_hot(self, idx_tensor):
        """
        Converts a 1-D tensor of token indices into a (len, vocab_size)
        one-hot tensor.
        :param idx_tensor: 1-D tensor of integer token indices.
        :return: float tensor of shape (idx_tensor.size(0), vocab_size).
        """
        one_hot_tensor = torch.empty((idx_tensor.size(0), self.vocab_size))
        for idx in range(idx_tensor.size(0)):
            zeros = torch.zeros((self.vocab_size))
            zeros[idx_tensor[idx].item()] = 1.0
            one_hot_tensor[idx] = zeros

        return one_hot_tensor

    def set_requires_grad(self, nets, requires_grad=False):
        """
        Enables/disables gradient tracking for all parameters of one or more
        networks.
        :param nets: a network or a list of networks (entries may be None).
        :param requires_grad: desired requires_grad flag (boolean).
        :return: None
        """
        if not isinstance(nets, list):
            nets = [nets]
        for net in nets:
            if net is not None:
                for param in net.parameters():
                    param.requires_grad = requires_grad

    def clip_gradient(self, model):
        """Clips the L2 norm of the model's gradients to 0.25."""
        nn.utils.clip_grad_norm_(model.parameters(), 0.25)
Ejemplo n.º 7
0
def gen(model, input_data, optimizer, loss, batch):
    """
    Runs a single VAE generator update: forward pass through the model,
    loss computation, back-propagation and one optimizer step.
    """
    reset_grads()
    reconstruction, mu, logvar, _ = model(input_data)
    generator_loss = loss(reconstruction, input_data, mu, logvar, batch)
    generator_loss.backward()
    optimizer.step()


running_counter = 0

for epoch in range(1, args.epochs + 1):
    print('---- Epoch {} ----'.format(epoch))
    iteration = 0
    model_source.train()
    model_target.train()
    discriminator_model.train()
    source_iter = iter(train_loader_source)
    target_iter = iter(train_loader_target)
    # ---------- Train --------------
    # Iterate until the shorter of the two loaders is exhausted
    while iteration < len(source_iter) and iteration < len(target_iter):
        running_counter += 1
        iteration += 1
        # Fix: iterator.next() is Python 2 only and raises AttributeError on
        # Python 3 (which this file targets — it uses f-strings elsewhere);
        # use the builtin next() instead.
        source_input, _ = next(source_iter)
        source_input = Variable(source_input)
        target_input, _ = next(target_iter)
        target_input = Variable(target_input)

        if args.cuda:
            source_input = source_input.cuda()
            target_input = target_input.cuda()
Ejemplo n.º 8
0
                batch_idx += 1
                if batch_idx % 500 == 0:
                    print(
                        f"Pre-train G epoch {epoch}, batch {batch_idx}, loss {loss}"
                    )
            print("\n")
        ckpt.save(sess)

        print("Pre-training Discriminator...")
        for epoch in range(30):
            batch_idx = 0
            for w1, w2, _, _, sent1, sent2 in get_batch(data,
                                                        FLAGS.batch_size,
                                                        shuffle=True):
                g_sent1, g_sent2, _, _ = G.generate(sess, w1, w2)
                loss = D.train(sess, sent1, sent2, g_sent1, g_sent2)
                batch_idx += 1
                if batch_idx % 500 == 0:
                    print(
                        f"Pre-train D epoch {epoch}, batch {batch_idx}, loss {loss}"
                    )
            print("\n")
        ckpt.save(sess)

        print("Adversarial Training...")
        for epoch in range(3):
            batch_idx = 0
            for w1, w2, x1, x2, sent1, sent2 in get_batch(data,
                                                          FLAGS.batch_size,
                                                          shuffle=True):
                g_sent1, g_sent2, gen_x1, gen_x2 = G.generate(sess, w1, w2)