def one_im_discrim(discrim_path, im_path):
    """Score a single image with a trained discriminator and print the result.

    Args:
        discrim_path: path to a saved ``Discriminator`` state dict.
        im_path: path to the image file to score.
    """
    discriminator = Discriminator(3, 64)
    discriminator.load_state_dict(
        torch.load(discrim_path, map_location=torch.device('cpu')))
    discriminator.eval()

    # Bug fix: the original called ``torchImage.open`` (an undefined name)
    # and discarded the result, then re-opened the file via ``Image.open``.
    # Open the image once and reuse it.
    tensor = transforms.ToTensor()
    im = Image.open(im_path)

    result = discriminator(tensor(im)).view(-1)
    print(result.data.item())
def run_model(model_path, discrim_path):
    """Visually inspect a deblurring model against its discriminator.

    For each (blurred, non-blurred) pair in the dataset, shows the blurred
    input, the ground-truth sharp image, and the model's output, each titled
    with the discriminator's scalar score for that image.

    Args:
        model_path: path to a saved ``Deblurrer`` state dict.
        discrim_path: path to a saved ``Discriminator`` state dict.
    """
    model = Deblurrer()
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device('cpu')))
    model.eval()

    discriminator = Discriminator(3, 64)
    discriminator.load_state_dict(
        torch.load(discrim_path, map_location=torch.device('cpu')))
    discriminator.eval()

    # NOTE(review): relative paths assume this is run from the project's
    # script directory — confirm against the repo layout.
    dataset = LFWC(["../data/train/faces_blurred"], "../data/train/faces")
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=1,
                                              shuffle=True)

    transformback = transforms.ToPILImage()

    # Pure inference/visualization: no gradients needed.  The deprecated
    # ``Variable`` wrappers from the original were no-ops and are removed.
    with torch.no_grad():
        for data in data_loader:
            blurred_img = data['blurred']
            nonblurred = data['nonblurred']

            # Should be near zero
            discrim_output_blurred = discriminator(blurred_img).view(
                -1).data.item()
            # Should be near one
            discrim_output_nonblurred = discriminator(nonblurred).view(
                -1).data.item()

            plt.imshow(transformback(blurred_img[0]))
            plt.title('Blurred, Discrim value: ' +
                      str(discrim_output_blurred))
            plt.show()
            plt.imshow(transformback(nonblurred[0]))
            plt.title('Non Blurred, Discrim value: ' +
                      str(discrim_output_nonblurred))
            plt.show()

            out = model(blurred_img)
            discrim_output_model = discriminator(out).view(-1).data.item()
            outIm = transformback(out[0])

            plt.imshow(outIm)
            plt.title('Model out, Discrim value: ' +
                      str(discrim_output_model))
            plt.show()
class CycleGAN(AlignmentModel):
    """This class implements the alignment model for GAN networks with two generators and two discriminators
    (cycle GAN). For description of the implemented functions, refer to the alignment model."""
    def __init__(self,
                 device,
                 config,
                 generator_a=None,
                 generator_b=None,
                 discriminator_a=None,
                 discriminator_b=None):
        """Initialize two new generators and two discriminators from the config or use pre-trained ones and create
        optimizers for all models.

        The four model/optimizer pairs follow the same construction recipe,
        so the copy-pasted branches of the original are factored into the
        ``_build_*`` helpers below.
        """
        super().__init__(device, config)
        # Per-epoch running losses: [gen_a, gen_b, disc_a, disc_b].
        self.epoch_losses = [0., 0., 0., 0.]

        # Generator A maps dim_b -> dim_a; generator B maps dim_a -> dim_b.
        if generator_a is None:
            generator_a = self._build_generator(config, config['dim_b'],
                                                config['dim_a'], device)
        self.generator_a = generator_a
        self.optimizer_g_a = self._build_optimizer(
            self.generator_a.parameters(), config)

        if generator_b is None:
            generator_b = self._build_generator(config, config['dim_a'],
                                                config['dim_b'], device)
        self.generator_b = generator_b
        self.optimizer_g_b = self._build_optimizer(
            self.generator_b.parameters(), config)

        if discriminator_a is None:
            discriminator_a = self._build_discriminator(
                config, config['dim_a'], device)
        self.discriminator_a = discriminator_a
        self.optimizer_d_a = self._build_optimizer(
            self.discriminator_a.parameters(), config)

        if discriminator_b is None:
            discriminator_b = self._build_discriminator(
                config, config['dim_b'], device)
        self.discriminator_b = discriminator_b
        self.optimizer_d_b = self._build_optimizer(
            self.discriminator_b.parameters(), config)

    @staticmethod
    def _build_generator(config, dim_1, dim_2, device):
        """Construct a new Generator mapping ``dim_1`` to ``dim_2`` and move it to ``device``."""
        generator_conf = dict(
            dim_1=dim_1,
            dim_2=dim_2,
            layer_number=config['generator_layers'],
            layer_expansion=config['generator_expansion'],
            initialize_generator=config['initialize_generator'],
            norm=config['gen_norm'],
            batch_norm=config['gen_batch_norm'],
            activation=config['gen_activation'],
            dropout=config['gen_dropout'])
        generator = Generator(generator_conf, device)
        generator.to(device)
        return generator

    @staticmethod
    def _build_discriminator(config, dim, device):
        """Construct a new Discriminator for inputs of dimension ``dim`` and move it to ``device``."""
        discriminator_conf = dict(
            dim=dim,
            layer_number=config['discriminator_layers'],
            layer_expansion=config['discriminator_expansion'],
            batch_norm=config['disc_batch_norm'],
            activation=config['disc_activation'],
            dropout=config['disc_dropout'])
        discriminator = Discriminator(discriminator_conf, device)
        discriminator.to(device)
        return discriminator

    @staticmethod
    def _build_optimizer(parameters, config):
        """Create an optimizer for ``parameters`` following the config's precedence.

        Precedence (same as the original copy-pasted branches):
        1. ``config['optimizer']`` — always called with the learning rate;
        2. ``config['optimizer_default']`` — SGD gets the learning rate,
           any other default optimizer uses its own defaults;
        3. otherwise Adam with ``config['learning_rate']``.
        """
        if 'optimizer' in config:
            return OPTIMIZERS[config['optimizer']](parameters,
                                                   config['learning_rate'])
        if 'optimizer_default' in config:
            if config['optimizer_default'] == 'sgd':
                return OPTIMIZERS[config['optimizer_default']](
                    parameters, config['learning_rate'])
            return OPTIMIZERS[config['optimizer_default']](parameters)
        return torch.optim.Adam(parameters, config['learning_rate'])

    def train(self):
        """Put all four sub-models into training mode."""
        self.generator_a.train()
        self.generator_b.train()
        self.discriminator_a.train()
        self.discriminator_b.train()

    def eval(self):
        """Put all four sub-models into evaluation mode."""
        self.generator_a.eval()
        self.generator_b.eval()
        self.discriminator_a.eval()
        self.discriminator_b.eval()

    def zero_grad(self):
        """Clear the gradients of all four optimizers."""
        self.optimizer_g_a.zero_grad()
        self.optimizer_g_b.zero_grad()
        self.optimizer_d_a.zero_grad()
        self.optimizer_d_b.zero_grad()

    def optimize_all(self):
        """Do the optimization step for all generators and discriminators."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def optimize_generator(self):
        """Do the optimization step only for generators (e.g. when training generators and discriminators separately or
        in turns)."""
        self.optimizer_g_a.step()
        self.optimizer_g_b.step()

    def optimize_discriminator(self):
        """Do the optimization step only for discriminators (e.g. when training generators and discriminators separately
        or in turns)."""
        self.optimizer_d_a.step()
        self.optimizer_d_b.step()

    def change_lr(self, factor):
        """Scale the current learning rate by ``factor``.

        NOTE(review): only the *generator* optimizers are updated here, as in
        the original — confirm the discriminator learning rates are meant to
        stay fixed.
        """
        self.current_lr = self.current_lr * factor
        for param_group in self.optimizer_g_a.param_groups:
            param_group['lr'] = self.current_lr
        for param_group in self.optimizer_g_b.param_groups:
            param_group['lr'] = self.current_lr

    def update_losses_batch(self, *losses):
        """Accumulate one batch's (gen_a, gen_b, disc_a, disc_b) losses into the epoch totals."""
        loss_g_a, loss_g_b, loss_d_a, loss_d_b = losses
        self.epoch_losses[0] += loss_g_a
        self.epoch_losses[1] += loss_g_b
        self.epoch_losses[2] += loss_d_a
        self.epoch_losses[3] += loss_d_b

    def complete_epoch(self, epoch_metrics):
        """Record this epoch's metrics (plus total loss) and reset the per-epoch accumulators."""
        self.metrics.append(epoch_metrics + [sum(self.epoch_losses)])
        self.losses.append(self.epoch_losses)
        # Rebind (rather than clear in place) so the list stored in
        # self.losses keeps this epoch's values.
        self.epoch_losses = [0., 0., 0., 0.]

    def print_epoch_info(self):
        """Print the epoch number, the four last losses and the last metrics."""
        print(
            f"{len(self.metrics)} ### {self.losses[-1][0]:.2f} - {self.losses[-1][1]:.2f} "
            f"- {self.losses[-1][2]:.2f} - {self.losses[-1][3]:.2f} ### {self.metrics[-1]}"
        )

    def copy_model(self):
        """Snapshot the state dicts of all four sub-models (for later restore)."""
        self.model_copy = deepcopy(self.generator_a.state_dict()), deepcopy(self.generator_b.state_dict()),\
                          deepcopy(self.discriminator_a.state_dict()), deepcopy(self.discriminator_b.state_dict())

    def restore_model(self):
        """Restore all four sub-models from the snapshot made by ``copy_model``."""
        self.generator_a.load_state_dict(self.model_copy[0])
        self.generator_b.load_state_dict(self.model_copy[1])
        self.discriminator_a.load_state_dict(self.model_copy[2])
        self.discriminator_b.load_state_dict(self.model_copy[3])

    def export_model(self, test_results, description=None):
        """Persist the four sub-models, the metrics and the test results under ``description``."""
        if description is None:
            description = f"CycleGAN_{self.config['evaluation']}_{self.config['subset']}"
        export_cyclegan_alignment(description, self.config, self.generator_a,
                                  self.generator_b, self.discriminator_a,
                                  self.discriminator_b, self.metrics)
        save_alignment_test_results(test_results, description)
        print(f"Saved model to directory {description}.")

    @classmethod
    def load_model(cls, name, device):
        """Load a previously exported CycleGAN (models + config) and wrap it in a new instance."""
        generator_a, generator_b, discriminator_a, discriminator_b, config = load_cyclegan_alignment(
            name, device)
        model = cls(device, config, generator_a, generator_b, discriminator_a,
                    discriminator_b)
        return model
Exemple #4
0
def main(args):
    """Joint (adversarial) training of an NMT generator and a discriminator.

    Loads (or first pre-trains) generator and discriminator checkpoints, then
    for each epoch alternates between policy-gradient and MLE updates for the
    generator (coin flip per batch) and supervised updates for the
    discriminator, validates both, and checkpoints the generator.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    print("======printing args========")
    print(args)
    print("=================================")

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        print("Loading bin dataset")
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        print(f"Loading raw text dataset {args.data}")
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst
    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))
    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running meters for the generator (g_) and discriminator (d_).
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    # try to load generator model; pre-train one if no checkpoint exists
    g_model_path = 'checkpoints/generator/best_gmodel.pt'
    if not os.path.exists(g_model_path):
        print("Start training generator!")
        train_g(args, dataset)
    assert os.path.exists(g_model_path)
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load(g_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    print("Generator has successfully loaded!")

    # try to load discriminator model; pre-train one if no checkpoint exists
    d_model_path = 'checkpoints/discriminator/best_dmodel.pt'
    if not os.path.exists(d_model_path):
        print("Start training discriminator!")
        train_d(args, dataset)
    assert os.path.exists(d_model_path)
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    model_dict = discriminator.state_dict()
    pretrained_dict = torch.load(d_model_path)
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(model_dict)

    print("Discriminator has successfully loaded!")

    print("starting main training loop")

    torch.autograd.set_detect_anomaly(True)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/joint'):
        os.makedirs('checkpoints/joint')
    checkpoints_path = 'checkpoints/joint/'

    # define loss functions.  BCEWithLogitsLoss combines sigmoid + BCE, so
    # the discriminator output is used as raw logits below.
    d_criterion = torch.nn.BCEWithLogitsLoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizers.  getattr replaces the original eval() on a
    # config-supplied string (same lookup, no code execution).
    g_optimizer = getattr(torch.optim, args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                         args.g_learning_rate)

    d_optimizer = getattr(torch.optim, args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        itr = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate,
                             args.lr_shrink, g_optimizer)

        for i, sample in enumerate(itr):
            if use_cuda:
                # wrap input tensors in cuda tensors.  Pass the boolean flag,
                # not the `cuda` module (truthy either way, but the flag is
                # what make_variable expects).
                sample = utils.make_variable(sample, cuda=use_cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when rand > 50%
            rand = random.random()
            if rand >= 0.5:
                # policy gradient training: sample a translation, score it
                # with the (frozen) discriminator and use that as reward
                generator.decoder.is_testing = True
                sys_out_batch, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
                with torch.no_grad():
                    reward = discriminator(
                        sample['net_input']['src_tokens'],
                        prediction)
                train_trg_batch = sample['target']
                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                g_optimizer.zero_grad()
                pg_loss.backward()
                # clip_grad_norm_ is the in-place, non-deprecated API
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

                # oracle valid
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
            else:
                # MLE training
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    "G MLE loss at batch {0}: {1:.3f}, lr={2}".format(
                        i, g_logging_meters['train_loss'].avg,
                        g_optimizer.param_groups[0]['lr']))
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()
            num_update += 1

            # part II: train the discriminator
            bsz = sample['target'].size(0)
            src_sentence = sample['net_input']['src_tokens']
            # train with half human-translation and half machine translation

            true_sentence = sample['target']
            true_labels = Variable(
                torch.ones(sample['target'].size(0)).float())

            with torch.no_grad():
                generator.decoder.is_testing = True
                _, prediction, _ = generator(sample)
                generator.decoder.is_testing = False
            fake_sentence = prediction
            fake_labels = Variable(
                torch.zeros(sample['target'].size(0)).float())

            trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
            labels = torch.cat([true_labels, fake_labels], dim=0)

            # shuffle real/fake together and keep a half-real/half-fake
            # batch of the original size
            indices = np.random.permutation(2 * bsz)
            trg_sentence = trg_sentence[indices][:bsz]
            labels = labels[indices][:bsz]

            if use_cuda:
                labels = labels.cuda()

            disc_out = discriminator(src_sentence, trg_sentence)
            # BCEWithLogitsLoss requires float targets — the original's
            # labels.long() raised a dtype error.
            d_loss = d_criterion(disc_out, labels)
            # torch.sigmoid is the function; torch.Sigmoid() does not exist.
            acc = torch.sum(torch.sigmoid(disc_out).round() ==
                            labels).float() / len(labels)
            d_logging_meters['train_acc'].update(acc)
            d_logging_meters['train_loss'].update(d_loss)
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        itr = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(itr):
            with torch.no_grad():
                if use_cuda:
                    sample['id'] = sample['id'].cuda()
                    sample['net_input']['src_tokens'] = sample['net_input'][
                        'src_tokens'].cuda()
                    sample['net_input']['src_lengths'] = sample['net_input'][
                        'src_lengths'].cuda()
                    sample['net_input']['prev_output_tokens'] = sample[
                        'net_input']['prev_output_tokens'].cuda()
                    sample['target'] = sample['target'].cuda()

                # generator validation
                _, _, loss = generator(sample)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                loss = loss / sample_size / math.log(2)
                g_logging_meters['valid_loss'].update(loss, sample_size)
                logging.debug("G dev loss at batch {0}: {1:.3f}".format(
                    i, g_logging_meters['valid_loss'].avg))

                # discriminator validation
                bsz = sample['target'].size(0)
                src_sentence = sample['net_input']['src_tokens']
                # train with half human-translation and half machine translation

                true_sentence = sample['target']
                true_labels = Variable(
                    torch.ones(sample['target'].size(0)).float())

                with torch.no_grad():
                    generator.decoder.is_testing = True
                    _, prediction, _ = generator(sample)
                    generator.decoder.is_testing = False
                fake_sentence = prediction
                fake_labels = Variable(
                    torch.zeros(sample['target'].size(0)).float())

                trg_sentence = torch.cat([true_sentence, fake_sentence], dim=0)
                labels = torch.cat([true_labels, fake_labels], dim=0)

                indices = np.random.permutation(2 * bsz)
                trg_sentence = trg_sentence[indices][:bsz]
                labels = labels[indices][:bsz]

                if use_cuda:
                    labels = labels.cuda()

                disc_out = discriminator(src_sentence, trg_sentence,
                                         dataset.dst_dict.pad())
                d_loss = d_criterion(disc_out, labels)
                acc = torch.sum(torch.sigmoid(disc_out).round() ==
                                labels).float() / len(labels)
                d_logging_meters['valid_acc'].update(acc)
                d_logging_meters['valid_loss'].update(d_loss)

        # checkpoint the generator after every epoch...
        torch.save(generator,
                   open(
                       checkpoints_path + "joint_{0:.3f}.epoch_{1}.pt".format(
                           g_logging_meters['valid_loss'].avg, epoch_i), 'wb'),
                   pickle_module=dill)

        # ...and keep the best one by validation loss.
        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator,
                       open(checkpoints_path + "best_gmodel.pt", 'wb'),
                       pickle_module=dill)
Exemple #5
0
class SVM_Classifier:
    """Train/evaluate a linear classifier on features extracted by a pre-trained
    GAN discriminator over CIFAR-10."""

    def __init__(self, batch_size, image_size=64, model_path=None):
        """Load CIFAR-10 and a pre-trained discriminator used as a fixed feature extractor.

        Args:
            batch_size: batch size for both train and test loaders.
            image_size: side length of the (square) input images.
            model_path: path to the saved training state containing the
                discriminator weights.  Defaults to the original hard-coded
                checkpoint location for backward compatibility.
        """
        self.image_size = image_size
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")
        self.save_filename = f'model_{datetime.datetime.now().strftime("%a_%H_%M")}.sav'

        transform = transforms.Compose([
            # transforms.Resize(self.image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ])

        self.trainset = torchvision.datasets.CIFAR10(root='./data',
                                                     train=True,
                                                     download=True,
                                                     transform=transform)
        self.trainloader = data.DataLoader(self.trainset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           num_workers=2)

        self.testset = torchvision.datasets.CIFAR10(root='./data',
                                                    train=False,
                                                    download=True,
                                                    transform=transform)
        self.testloader = data.DataLoader(self.testset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          num_workers=2)

        # Generalized: the checkpoint path was hard-coded; it is now a
        # parameter with the original path as default.
        if model_path is None:
            model_path = (
                "C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\imagenet\\trained_model_Tue_17_06.pth"
            )
        saved_state = torch.load(model_path)
        self.discriminator = Discriminator(ngpu=1,
                                           num_channels=3,
                                           num_features=64,
                                           data_generation_mode=1,
                                           input_size=image_size)
        self.discriminator.load_state_dict(saved_state['discriminator'])
        self.discriminator.eval()  # change the mode of the network.

    def plot_training_data(self):
        """Show a grid of sample training images."""
        # Plot some training images
        real_batch = next(iter(self.trainloader))
        real_batch = real_batch[0][0:8]
        plt.figure(figsize=(8, 8))
        plt.axis("off")
        plt.title("Training Images")
        plt.imshow(
            np.transpose(
                vutils.make_grid(real_batch[0].to(self.device)[:64],
                                 padding=2,
                                 normalize=True).cpu(), (1, 2, 0)))
        plt.show()

    def train(self):
        """Fit a linear SVM on discriminator features of ONE training batch and pickle it.

        NOTE(review): only a single batch from the loader is used for
        fitting — confirm this is intentional (the SGD variant below
        iterates the full loader).
        """
        train_data, train_labels = next(iter(self.trainloader))
        modified_train_data = self.discriminator(train_data)
        l2_svm = svm.LinearSVC(verbose=2, max_iter=2000)

        modified_train_data_ndarray = modified_train_data.detach().numpy()
        train_labels_ndarray = train_labels.detach().numpy()
        self.l2_svm = l2_svm.fit(modified_train_data_ndarray,
                                 train_labels_ndarray)

        # save model
        with open(self.save_filename, 'wb') as file:
            pickle.dump(self.l2_svm, file)

    def train_test_SGD_Classifier(self):
        """Incrementally fit a scaled SGD classifier on discriminator features over the full train set."""
        est = make_pipeline(StandardScaler(), SGDClassifier(max_iter=200))
        # Fit the scaler once on a first batch; later batches reuse it.
        training_data = self.discriminator(next(iter(self.trainloader))[0])
        training_data = training_data.detach().numpy()
        est.steps[0][1].fit(training_data)

        self.est = est

        for i, data in enumerate(self.trainloader):
            train_data, train_labels = data
            modified_train_data = self.discriminator(train_data)

            modified_train_data_ndarray = modified_train_data.detach().numpy()
            train_labels_ndarray = train_labels.detach().numpy()
            modified_train_data_ndarray = est.steps[0][1].transform(
                modified_train_data_ndarray)

            est.steps[1][1].partial_fit(
                modified_train_data_ndarray,
                train_labels_ndarray,
                classes=np.unique(train_labels_ndarray))
            print(f'Batch: {i}')

        with open(self.save_filename, 'wb') as file:
            pickle.dump(est.steps[1][1], file)

    def test(self):
        """Evaluate the SGD classifier (requires ``train_test_SGD_Classifier`` first) and print mean accuracy."""
        l2_svm = self.est.steps[1][1]
        accuracy = []

        for i, data in enumerate(self.testloader):
            test_data, test_labels = data
            modified_test_data = self.discriminator(test_data)

            modified_test_data_ndarray = modified_test_data.detach().numpy()
            test_labels_ndarray = test_labels.detach().numpy()
            modified_test_data_ndarray = self.est.steps[0][1].transform(
                modified_test_data_ndarray)

            predictions = l2_svm.predict(modified_test_data_ndarray)

            accuracy.append(
                metrics.accuracy_score(test_labels_ndarray, predictions))

        print(f'Accuracy: {np.mean(accuracy)}')
Exemple #6
0
class trainer(object):
    """Adversarial trainer for semantic-label correction.

    Two U-Net generators (one reconstructing the old label map, one
    predicting a corrected label map from the RGB image) are trained against
    a patch discriminator that judges (image, one-hot label map) pairs.
    Checkpoints, validation images, and visdom loss curves are produced per
    epoch.
    """

    def __init__(self, cfg):
        """Build models, losses, data loaders, optimizers, and schedulers.

        cfg: project configuration object (DATASET / LOSS / TRAIN /
        OPTIMIZER namespaces).  If ``cfg.TRAIN.RESUME >= 0`` the matching
        checkpoint is restored.
        """
        self.cfg = cfg
        # Generator that reconstructs the (one-hot) old label map.
        self.OldLabel_generator = U_Net(in_ch=cfg.DATASET.N_CLASS,
                                        out_ch=cfg.DATASET.N_CLASS,
                                        side='out')
        # Generator that predicts the corrected label map from the image,
        # conditioned on features from the old-label generator.
        self.Image_generator = U_Net(in_ch=3,
                                     out_ch=cfg.DATASET.N_CLASS,
                                     side='in')
        # Patch discriminator over concatenated (image, one-hot labels).
        self.discriminator = Discriminator(cfg.DATASET.N_CLASS + 3,
                                           cfg.DATASET.IMGSIZE,
                                           patch=True)

        self.criterion_G = GeneratorLoss(cfg.LOSS.LOSS_WEIGHT[0],
                                         cfg.LOSS.LOSS_WEIGHT[1],
                                         cfg.LOSS.LOSS_WEIGHT[2],
                                         ignore_index=cfg.LOSS.IGNORE_INDEX)
        self.criterion_D = DiscriminatorLoss()

        train_dataset = BaseDataset(cfg, split='train')
        valid_dataset = BaseDataset(cfg, split='val')
        self.train_dataloader = data.DataLoader(
            train_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)
        self.valid_dataloader = data.DataLoader(
            valid_dataset,
            batch_size=cfg.DATASET.BATCHSIZE,
            num_workers=8,
            shuffle=True,
            drop_last=True)

        self.ckpt_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints')
        if not os.path.isdir(self.ckpt_outdir):
            os.mkdir(self.ckpt_outdir)
        self.val_outdir = os.path.join(cfg.TRAIN.OUTDIR, 'val')
        if not os.path.isdir(self.val_outdir):
            os.mkdir(self.val_outdir)
        # RESUME holds the checkpoint epoch to resume from (< 0 == fresh run).
        self.start_epoch = cfg.TRAIN.RESUME
        self.n_epoch = cfg.TRAIN.N_EPOCH

        self.optimizer_G = torch.optim.Adam(
            [{
                'params': self.OldLabel_generator.parameters()
            }, {
                'params': self.Image_generator.parameters()
            }],
            lr=cfg.OPTIMIZER.G_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        self.optimizer_D = torch.optim.Adam(
            [{
                'params': self.discriminator.parameters(),
                'initial_lr': cfg.OPTIMIZER.D_LR
            }],
            lr=cfg.OPTIMIZER.D_LR,
            betas=(cfg.OPTIMIZER.BETA1, cfg.OPTIMIZER.BETA2),
            weight_decay=cfg.OPTIMIZER.WEIGHT_DECAY)

        # Polynomial (power 0.9) LR decay over all training iterations;
        # stepped once per batch in train().
        iter_per_epoch = len(train_dataset) // cfg.DATASET.BATCHSIZE
        lambda_poly = lambda iters: pow(
            (1.0 - iters / (cfg.TRAIN.N_EPOCH * iter_per_epoch)), 0.9)
        self.scheduler_G = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_G,
            lr_lambda=lambda_poly,
        )
        self.scheduler_D = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_D,
            lr_lambda=lambda_poly,
        )

        self.logger = logger(cfg.TRAIN.OUTDIR, name='train')
        self.running_metrics = runningScore(n_classes=cfg.DATASET.N_CLASS)

        if self.start_epoch >= 0:
            # Load the checkpoint ONCE and restore every component from it
            # (the previous code re-read the same .pth file five times).
            checkpoint = torch.load(
                os.path.join(cfg.TRAIN.OUTDIR, 'checkpoints',
                             '{}epoch.pth'.format(self.start_epoch)))
            self.OldLabel_generator.load_state_dict(checkpoint['model_G_N'])
            self.Image_generator.load_state_dict(checkpoint['model_G_I'])
            self.discriminator.load_state_dict(checkpoint['model_D'])
            self.optimizer_G.load_state_dict(checkpoint['optimizer_G'])
            self.optimizer_D.load_state_dict(checkpoint['optimizer_D'])

            log = "Using the {}th checkpoint".format(self.start_epoch)
            self.logger.info(log)
        self.Image_generator = self.Image_generator.cuda()
        self.OldLabel_generator = self.OldLabel_generator.cuda()
        self.discriminator = self.discriminator.cuda()
        self.criterion_G = self.criterion_G.cuda()
        self.criterion_D = self.criterion_D.cuda()

    def train(self):
        """Run the adversarial training loop, validating after every epoch.

        Per batch: the discriminator is updated on a detached fake sample and
        the real sample, then the generators are updated through the combined
        generator loss.  Losses are averaged and plotted to visdom every 10
        batches; a checkpoint is saved after each epoch.
        """
        all_train_iter_total_loss = []
        all_train_iter_corr_loss = []
        all_train_iter_recover_loss = []
        all_train_iter_change_loss = []
        all_train_iter_gan_loss_gen = []
        all_train_iter_gan_loss_dis = []
        all_val_epo_iou = []
        all_val_epo_acc = []
        iter_num = [0]
        epoch_num = []
        num_batches = len(self.train_dataloader)

        for epoch_i in range(self.start_epoch + 1, self.n_epoch):
            iter_total_loss = AverageTracker()
            iter_corr_loss = AverageTracker()
            iter_recover_loss = AverageTracker()
            iter_change_loss = AverageTracker()
            iter_gan_loss_gen = AverageTracker()
            iter_gan_loss_dis = AverageTracker()
            batch_time = AverageTracker()
            tic = time.time()

            # train
            self.OldLabel_generator.train()
            self.Image_generator.train()
            self.discriminator.train()
            for i, meta in enumerate(self.train_dataloader):

                image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                ), meta[2].cuda()
                recover_pred, feats = self.OldLabel_generator(
                    label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                corr_pred = self.Image_generator(image, feats)

                # -------------------
                # Train Discriminator
                # -------------------
                self.discriminator.set_requires_grad(True)
                self.optimizer_D.zero_grad()

                # Fake sample is detached so generator gradients don't flow
                # through the discriminator update.
                fake_sample = torch.cat((image, corr_pred), 1).detach()
                # BUGFIX: was the bare name `cfg` (undefined in this scope);
                # must read the stored config via `self.cfg`.
                real_sample = torch.cat(
                    (image, label2onehot(new_label, self.cfg.DATASET.N_CLASS)),
                    1)

                score_fake_d = self.discriminator(fake_sample)
                score_real = self.discriminator(real_sample)

                gan_loss_dis = self.criterion_D(pred_score=score_fake_d,
                                                real_score=score_real)
                gan_loss_dis.backward()
                self.optimizer_D.step()
                self.scheduler_D.step()

                # ---------------
                # Train Generator
                # ---------------
                self.discriminator.set_requires_grad(False)
                self.optimizer_G.zero_grad()

                score_fake = self.discriminator(
                    torch.cat((image, corr_pred), 1))

                total_loss, corr_loss, recover_loss, change_loss, gan_loss_gen = self.criterion_G(
                    corr_pred, recover_pred, score_fake, old_label, new_label)

                total_loss.backward()
                self.optimizer_G.step()
                self.scheduler_G.step()

                iter_total_loss.update(total_loss.item())
                iter_corr_loss.update(corr_loss.item())
                iter_recover_loss.update(recover_loss.item())
                iter_change_loss.update(change_loss.item())
                iter_gan_loss_gen.update(gan_loss_gen.item())
                iter_gan_loss_dis.update(gan_loss_dis.item())
                batch_time.update(time.time() - tic)
                tic = time.time()

                log = '{}: Epoch: [{}][{}/{}], Time: {:.2f}, ' \
                      'Total Loss: {:.6f}, Corr Loss: {:.6f}, Recover Loss: {:.6f}, Change Loss: {:.6f}, GAN_G Loss: {:.6f}, GAN_D Loss: {:.6f}'.format(
                    datetime.now(), epoch_i, i, num_batches, batch_time.avg,
                    total_loss.item(), corr_loss.item(), recover_loss.item(), change_loss.item(), gan_loss_gen.item(), gan_loss_dis.item())
                print(log)

                # Every 10 batches: push running averages to visdom and reset.
                if (i + 1) % 10 == 0:
                    all_train_iter_total_loss.append(iter_total_loss.avg)
                    all_train_iter_corr_loss.append(iter_corr_loss.avg)
                    all_train_iter_recover_loss.append(iter_recover_loss.avg)
                    all_train_iter_change_loss.append(iter_change_loss.avg)
                    all_train_iter_gan_loss_gen.append(iter_gan_loss_gen.avg)
                    all_train_iter_gan_loss_dis.append(iter_gan_loss_dis.avg)
                    iter_total_loss.reset()
                    iter_corr_loss.reset()
                    iter_recover_loss.reset()
                    iter_change_loss.reset()
                    iter_gan_loss_gen.reset()
                    iter_gan_loss_dis.reset()

                    vis.line(X=np.column_stack(
                        np.repeat(np.expand_dims(iter_num, 0), 6, axis=0)),
                             Y=np.column_stack((all_train_iter_total_loss,
                                                all_train_iter_corr_loss,
                                                all_train_iter_recover_loss,
                                                all_train_iter_change_loss,
                                                all_train_iter_gan_loss_gen,
                                                all_train_iter_gan_loss_dis)),
                             opts={
                                 'legend': [
                                     'total_loss', 'corr_loss', 'recover_loss',
                                     'change_loss', 'gan_loss_gen',
                                     'gan_loss_dis'
                                 ],
                                 'linecolor':
                                 np.array([[255, 0, 0], [0, 255, 0],
                                           [0, 0, 255], [255, 255, 0],
                                           [0, 255, 255], [255, 0, 255]]),
                                 'title':
                                 'Train loss of generator and discriminator'
                             },
                             win='Train loss of generator and discriminator')
                    iter_num.append(iter_num[-1] + 1)

            # eval
            self.OldLabel_generator.eval()
            self.Image_generator.eval()
            self.discriminator.eval()
            with torch.no_grad():
                for j, meta in enumerate(self.valid_dataloader):
                    image, old_label, new_label = meta[0].cuda(), meta[1].cuda(
                    ), meta[2].cuda()
                    recover_pred, feats = self.OldLabel_generator(
                        label2onehot(old_label, self.cfg.DATASET.N_CLASS))
                    corr_pred = self.Image_generator(image, feats)
                    preds = np.argmax(corr_pred.cpu().detach().numpy().copy(),
                                      axis=1)
                    target = new_label.cpu().detach().numpy().copy()
                    self.running_metrics.update(target, preds)

                    # Dump a color-mapped prediction pair for the first batch
                    # of every validation epoch.
                    if j == 0:
                        color_map1 = gen_color_map(preds[0, :]).astype(
                            np.uint8)
                        color_map2 = gen_color_map(preds[1, :]).astype(
                            np.uint8)
                        color_map = cv2.hconcat([color_map1, color_map2])
                        cv2.imwrite(
                            os.path.join(
                                self.val_outdir, '{}epoch*{}*{}.png'.format(
                                    epoch_i, meta[3][0], meta[3][1])),
                            color_map)

            score = self.running_metrics.get_scores()
            oa = score['Overall Acc: \t']
            precision = score['Precision: \t'][1]
            recall = score['Recall: \t'][1]
            iou = score['Class IoU: \t'][1]
            miou = score['Mean IoU: \t']
            self.running_metrics.reset()

            epoch_num.append(epoch_i)
            all_val_epo_acc.append(oa)
            all_val_epo_iou.append(miou)
            vis.line(X=np.column_stack(
                np.repeat(np.expand_dims(epoch_num, 0), 2, axis=0)),
                     Y=np.column_stack((all_val_epo_acc, all_val_epo_iou)),
                     opts={
                         'legend':
                         ['val epoch Overall Acc', 'val epoch Mean IoU'],
                         'linecolor': np.array([[255, 0, 0], [0, 255, 0]]),
                         'title': 'Validate Accuracy and IoU'
                     },
                     win='validate Accuracy and IoU')

            log = '{}: Epoch Val: [{}], ACC: {:.2f}, Recall: {:.2f}, mIoU: {:.4f}' \
                .format(datetime.now(), epoch_i, oa, recall, miou)
            self.logger.info(log)

            state = {
                'epoch': epoch_i,
                "acc": oa,
                "recall": recall,
                "iou": miou,
                'model_G_N': self.OldLabel_generator.state_dict(),
                'model_G_I': self.Image_generator.state_dict(),
                'model_D': self.discriminator.state_dict(),
                'optimizer_G': self.optimizer_G.state_dict(),
                'optimizer_D': self.optimizer_D.state_dict()
            }
            save_path = os.path.join(self.cfg.TRAIN.OUTDIR, 'checkpoints',
                                     '{}epoch.pth'.format(epoch_i))
            torch.save(state, save_path)
Exemple #7
0
class GAIL:
    """Generative Adversarial Imitation Learning agent.

    An actor network is trained to imitate expert trajectories by fooling a
    discriminator that separates expert (state, action) pairs from the
    actor's own transitions.
    """

    def __init__(self,
                 exp_dir,
                 exp_thresh,
                 state_dim,
                 action_dim,
                 learn_rate,
                 betas,
                 _device,
                 _gamma,
                 load_weights=False):
        """
            exp_dir : directory containing the expert episodes
         exp_thresh : parameter to control number of episodes to load
                      as expert based on returns (lower means more episodes)
          state_dim : dimension of state
         action_dim : dimension of action
         learn_rate : learning rate for optimizer
              betas : Adam beta coefficients
            _device : GPU or CPU torch device
            _gamma  : discount factor
       load_weights : load weights from directory instead of re-initializing
        """

        # storing runtime device
        self.device = _device

        # discount factor
        self.gamma = _gamma

        # Expert trajectory
        self.expert = ExpertTrajectories(exp_dir, exp_thresh, gamma=self.gamma)

        # Defining the actor and its optimizer
        self.actor = ActorNetwork(state_dim).to(self.device)
        self.optim_actor = torch.optim.Adam(self.actor.parameters(),
                                            lr=learn_rate,
                                            betas=betas)

        # Defining the discriminator and its optimizer
        self.disc = Discriminator(state_dim, action_dim).to(self.device)
        self.optim_disc = torch.optim.Adam(self.disc.parameters(),
                                           lr=learn_rate,
                                           betas=betas)

        if not load_weights:
            self.actor.apply(init_weights)
            self.disc.apply(init_weights)
        else:
            self.load()

        # Loss function criterion
        self.criterion = torch.nn.BCELoss()

    def get_action(self, state):
        """Return the actor's action for *state* as a flat numpy array."""
        state = torch.tensor(state, dtype=torch.float,
                             device=self.device).view(1, -1)
        return self.actor(state).cpu().data.numpy().flatten()

    def update(self, n_iter, batch_size=100):
        """Train discriminator and actor for *n_iter* mini-batches.

        Returns ``(act_losses, disc_losses)`` as numpy arrays of length
        *n_iter*.
        """
        # np.float was removed from NumPy; the builtin float (float64) is the
        # drop-in replacement.
        disc_losses = np.zeros(n_iter, dtype=float)
        act_losses = np.zeros(n_iter, dtype=float)

        for i in range(n_iter):

            # Get expert state and actions batch
            exp_states, exp_actions = self.expert.sample(batch_size)
            exp_states = torch.FloatTensor(exp_states).to(self.device)
            exp_actions = torch.FloatTensor(exp_actions).to(self.device)

            # Get state, and actions using actor
            states, _ = self.expert.sample(batch_size)
            states = torch.FloatTensor(states).to(self.device)
            actions = self.actor(states)
            '''
                train the discriminator
            '''
            self.optim_disc.zero_grad()

            # Label tensors: float fill values keep the tensors float-typed
            # (integer fills produce long tensors on recent PyTorch, which
            # BCELoss rejects).
            exp_labels = torch.full((batch_size, 1), 1.0, device=self.device)
            policy_labels = torch.full((batch_size, 1), 0.0,
                                       device=self.device)

            # with expert transitions
            prob_exp = self.disc(exp_states, exp_actions)
            exp_loss = self.criterion(prob_exp, exp_labels)

            # with policy actor transitions (detach: no actor grads here)
            prob_policy = self.disc(states, actions.detach())
            policy_loss = self.criterion(prob_policy, policy_labels)

            # use backprop
            disc_loss = exp_loss + policy_loss
            disc_losses[i] = disc_loss.mean().item()

            disc_loss.backward()
            self.optim_disc.step()
            '''
                train the actor
            '''
            self.optim_actor.zero_grad()
            # Maximize discriminator output for actor transitions.
            loss_actor = -self.disc(states, actions)
            act_losses[i] = loss_actor.mean().detach().item()

            loss_actor.mean().backward()
            self.optim_actor.step()

        print("Finished training minibatch")

        return act_losses, disc_losses

    def save(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Persist actor and discriminator weights under *directory*."""
        torch.save(self.actor.state_dict(),
                   '{}/{}_actor.pth'.format(directory, name))
        torch.save(self.disc.state_dict(),
                   '{}/{}_discriminator.pth'.format(directory, name))

    def load(
            self,
            directory='/home/aman/Programming/RL-Project/Deterministic-GAIL/weights',
            name='GAIL'):
        """Restore actor and discriminator weights from *directory*."""
        print(os.getcwd())
        self.actor.load_state_dict(
            torch.load('{}/{}_actor.pth'.format(directory, name)))
        self.disc.load_state_dict(
            torch.load('{}/{}_discriminator.pth'.format(directory, name)))

    def set_mode(self, mode="train"):
        """Switch actor and discriminator to train or eval mode."""
        if mode == "train":
            self.actor.train()
            self.disc.train()
        else:
            self.actor.eval()
            self.disc.eval()
class Trainer:
    """Adversarially aligns FastText embedding halves with a discriminator.

    The vocabulary is split at ``params.split``; the discriminator learns to
    tell embeddings of the two halves apart while the embedding tables are
    trained (via reversed labels) to make them indistinguishable.
    """

    def __init__(self, corpus_data, *, params):
        self.fast_text = FastText(corpus_data.model).to(GPU)
        self.discriminator = Discriminator(
            params.emb_dim,
            n_layers=params.d_n_layers,
            n_units=params.d_n_units,
            drop_prob=params.d_drop_prob,
            drop_prob_input=params.d_drop_prob_input,
            leaky=params.d_leaky,
            batch_norm=params.d_bn).to(GPU)
        self.ft_optimizer = optim.SGD(self.fast_text.parameters(),
                                      lr=params.ft_lr)
        self.d_optimizer = optim.SGD(self.discriminator.parameters(),
                                     lr=params.d_lr,
                                     weight_decay=params.d_wd)
        # Adversarial optimizer updates only the embedding tables (u, v).
        self.a_optimizer = optim.SGD([{
            "params": self.fast_text.u.parameters()
        }, {
            "params": self.fast_text.v.parameters()
        }],
                                     lr=params.a_lr)
        self.smooth = params.smooth
        # "elementwise_mean" was renamed to "mean" and the old spelling is
        # rejected by current PyTorch; behavior is identical.
        self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.corpus_data_queue = _data_queue(corpus_data,
                                             n_threads=params.n_threads,
                                             n_sentences=params.n_sentences,
                                             batch_size=params.ft_bs)
        self.vocab_size = params.vocab_size
        self.d_bs = params.d_bs
        self.split = params.split
        self.align_output = params.align_output

    def fast_text_step(self):
        """One FastText optimization step; returns the loss value."""
        self.ft_optimizer.zero_grad()
        u_b, v_b = self.corpus_data_queue.__next__()
        s = self.fast_text(u_b, v_b)
        loss = FastText.loss_fn(s)
        loss.backward()
        self.ft_optimizer.step()
        return loss.item()

    def get_adv_batch(self, *, reverse, fix_embedding):
        """Build an adversarial (embeddings, smoothed labels) batch.

        Word ids are drawn from both vocabulary halves in proportion to
        ``self.split``.  *reverse* flips which half gets the "real" labels
        (used when training the embeddings instead of the discriminator);
        *fix_embedding* blocks gradients through the embedding lookup.
        """
        vocab_split, bs_split = int(self.vocab_size * self.split), int(
            self.d_bs * self.split)
        x = (torch.randint(0, vocab_split, size=(bs_split, ),
                           dtype=torch.long).tolist() +
             torch.randint(vocab_split,
                           self.vocab_size,
                           size=(self.d_bs - bs_split, ),
                           dtype=torch.long).tolist())
        if self.align_output:
            x = torch.LongTensor(x).view(self.d_bs, 1).to(GPU)
            if fix_embedding:
                with torch.no_grad():
                    x = self.fast_text.v(x).view(self.d_bs, -1)
            else:
                x = self.fast_text.v(x).view(self.d_bs, -1)
        else:
            x = self.fast_text.model.get_bag(x, self.fast_text.u.weight.device)
            if fix_embedding:
                with torch.no_grad():
                    x = self.fast_text.u(x[0], x[1]).view(self.d_bs, -1)
            else:
                x = self.fast_text.u(x[0], x[1]).view(self.d_bs, -1)
        # Smoothed labels in [0, smooth) for one half, (1-smooth, 1] for the
        # other.
        y = torch.FloatTensor(self.d_bs).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:bs_split] = 1 - y[:bs_split]
        else:
            y[bs_split:] = 1 - y[bs_split:]
        return x, y

    def discriminator_step(self):
        """One discriminator step on a fixed-embedding batch; returns loss."""
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            x, y = self.get_adv_batch(reverse=False, fix_embedding=True)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.d_optimizer.step()
        return loss.item()

    def adversarial_step(self):
        """One embedding step against the (frozen) discriminator."""
        self.a_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=False)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.a_optimizer.step()
        return loss.item()
Exemple #9
0
    """
    pixels = X.reshape((28, 28))
    plt.title(str(digit))
    plt.imshow(pixels, cmap='gray')
    plt.show()


# Draw 60000 normally-distributed 784-dim vectors for the discriminator data.
X = get_normal_shaped_arrays(60000, (1, 784))

# NOTE(review): `X_train` is passed as an argument here before any visible
# assignment in this chunk -- presumably it is defined earlier in the file
# (e.g. real MNIST samples); verify against the full script.
X_train, y_train, X_test, y_test = discriminator_train_test_set(
    X, X_train, params.DISCRIMINATOR_TRAIN_TEST_SPLIT)

# Train the discriminator alone, then print its test evaluation.
discriminator = Discriminator(params.DISCRIMINATOR_BATCH_SIZE,
                              params.DISCRIMINATOR_EPOCHS)
discriminator.train(X_train, y_train)
print(discriminator.eval(X_test, y_test))

generator = Generator()

# Freeze the discriminator inside the combined GAN before generator training.
gan = Gan(generator, discriminator)
gan.set_discriminator_trainability(False)
gan.show_trainable()

# 100000 random 16-dim noise vectors as generator input.
X = get_normal_shaped_arrays(100000, (1, 16))
y = []
for _ in range(100000):
    y.append([0, 1])

# Every noise sample is labeled [0, 1] ("real") to drive generator training.
y = np.array(y)

generator = gan.train_generator(X, y)
class Trainer:
    """Trains two SkipGram embeddings plus a linear mapping adversarially.

    A discriminator tries to distinguish (mapped) embeddings of language 0
    from embeddings of language 1; the embeddings and the mapping are trained
    against it.  The mapping is kept near-orthogonal via ``_orthogonalize``.
    """

    def __init__(self, corpus_data_0, corpus_data_1, *, params, n_samples=10000000):
        self.skip_gram = [SkipGram(corpus_data_0.vocab_size + 1, params.emb_dim).to(GPU),
                          SkipGram(corpus_data_1.vocab_size + 1, params.emb_dim).to(GPU)]
        self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units,
                                           drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input,
                                           leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU)
        # Linear mapping from language-0 space to language-1 space,
        # initialized to the identity matrix.
        self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
        self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
        self.mapping = self.mapping.to(GPU)
        self.sg_optimizer, self.sg_scheduler = [], []
        for idx in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(self.skip_gram[idx].parameters(),
                                                            lr=params.sg_lr, mode="max")
            self.sg_optimizer.append(optimizer)
            self.sg_scheduler.append(scheduler)
        # Adversarial optimizers act only on the embedding tables (u, v).
        self.a_optimizer, self.a_scheduler = [], []
        for idx in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(
                [{"params": self.skip_gram[idx].u.parameters()}, {"params": self.skip_gram[idx].v.parameters()}],
                lr=params.a_lr, mode="max")
            self.a_optimizer.append(optimizer)
            self.a_scheduler.append(scheduler)
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(self.discriminator.parameters(),
                                                                          lr=params.d_lr, mode="max", wd=params.d_wd)

        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(self.discriminator.parameters(),
                                                                               params.n_steps,
                                                                               lr=params.d_lr, wd=params.d_wd)
        else:
            raise Exception(f"Optimizer {params.d_optimizer} not found.")
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(self.mapping.parameters(),
                                                                          lr=params.m_lr, mode="max", wd=params.m_wd)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(self.mapping.parameters(),
                                                                               params.n_steps,
                                                                               lr=params.m_lr, wd=params.m_wd)
        else:
            raise Exception(f"Optimizer {params.m_optimizer} not found")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        # "elementwise_mean" was renamed to "mean"; the old spelling is
        # rejected by current PyTorch. Behavior is identical.
        self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        self.corpus_data_queue = [
            _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences,
                        batch_size=params.sg_bs),
            _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences,
                        batch_size=params.sg_bs)
        ]
        self.sampler = [
            WordSampler(corpus_data_0.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top),
            WordSampler(corpus_data_1.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top)]
        self.d_bs = params.d_bs

    def skip_gram_step(self):
        """One SkipGram step per language; returns both loss values."""
        losses = []
        for idx in [0, 1]:
            self.sg_optimizer[idx].zero_grad()
            pos_u_b, pos_v_b, neg_v_b = self.corpus_data_queue[idx].__next__()
            pos_s, neg_s = self.skip_gram[idx](pos_u_b, pos_v_b, neg_v_b)
            loss = SkipGram.loss_fn(pos_s, neg_s)
            loss.backward()
            self.sg_optimizer[idx].step()
            losses.append(loss.item())
        return losses[0], losses[1]

    def get_adv_batch(self, *, reverse, fix_embedding=False):
        """Build an adversarial (embeddings, smoothed labels) batch.

        Samples ``d_bs`` words per language, embeds them (optionally with
        gradients blocked), maps language 0 into language 1's space, and
        returns the concatenated embeddings with smoothed labels.  *reverse*
        flips which language receives the "real" labels.
        """
        batch = [torch.LongTensor([self.sampler[idx].sample() for _ in range(self.d_bs)]).view(self.d_bs, 1).to(GPU)
                 for idx in [0, 1]]
        if fix_embedding:
            with torch.no_grad():
                x = [self.skip_gram[idx].u(batch[idx]).view(self.d_bs, -1) for idx in [0, 1]]
        else:
            x = [self.skip_gram[idx].u(batch[idx]).view(self.d_bs, -1) for idx in [0, 1]]
        x[0] = self.mapping(x[0])
        x = torch.cat(x, 0)
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[: self.d_bs] = 1 - y[: self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        return x, y

    def adversarial_step(self, fix_embedding=False):
        """One adversarial step for embeddings and mapping; returns loss."""
        for idx in [0, 1]:
            self.a_optimizer[idx].zero_grad()
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        for idx in [0, 1]:
            self.a_optimizer[idx].step()
        self.m_optimizer.step()
        # Keep the mapping close to orthogonal after each update.
        _orthogonalize(self.mapping, self.m_beta)
        return loss.item()

    def discriminator_step(self):
        """One discriminator step on a fixed-embedding batch; returns loss."""
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        with torch.no_grad():
            x, y = self.get_adv_batch(reverse=False)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.d_optimizer.step()
        return loss.item()

    def scheduler_step(self, metric):
        """Step the SkipGram, adversarial, and mapping schedulers on *metric*."""
        for idx in [0, 1]:
            self.sg_scheduler[idx].step(metric)
            self.a_scheduler[idx].step(metric)
        # self.d_scheduler.step(metric)
        self.m_scheduler.step(metric)
def main(args):
    """Jointly train an LSTM NMT generator against two discriminators.

    ``discriminator_h`` scores decoder hidden states (teacher-forced vs.
    free-running, professor-forcing style); ``discriminator_s`` scores
    (source, translation) sentence pairs.  Each epoch alternates
    policy-gradient and MLE generator updates, updates both discriminators
    every 5 generator steps, validates, and checkpoints the generator,
    tracking the best model by validation loss.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(
            args.data, splits, args.src_lang, args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src, len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst, len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split, len(dataset.splits[split])))

    # running-average meters for generator statistics
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # running-average meters for discriminator statistics
    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0.3
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0.3
    args.bidirectional = False

    generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator_h = Discriminator_h(args.decoder_embed_dim, args.discriminator_hidden_size, args.discriminator_linear_size, args.discriminator_lin_dropout, use_cuda=use_cuda)
    print("Discriminator_h loaded successfully!")
    discriminator_s = Discriminator_s(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    print("Discriminator_s loaded successfully!")

    # NOTE(review): name is misspelled ("calcualte"); local to main, safe to
    # rename later.
    def _calcualte_discriminator_loss(tf_scores, ar_scores):
        # Professor-forcing style loss:
        # -log(D(teacher_forced)) - log(1 - D(autoregressive)), with a
        # small epsilon to keep log() finite.
        tf_loss = torch.log(tf_scores + 1e-9) * (-1)
        ar_loss = torch.log(1 - ar_scores + 1e-9) * (-1)
        return tf_loss + ar_loss

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator_h = torch.nn.DataParallel(discriminator_h).cuda()
            discriminator_s = torch.nn.DataParallel(discriminator_s).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator_h.cuda()
            discriminator_s.cuda()
    else:
        discriminator_h.cpu()
        discriminator_s.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/professor2'):
        os.makedirs('checkpoints/professor2')
    checkpoints_path = 'checkpoints/professor2/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(), reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(), size_average=True, reduce=True)

    # fix discriminator_s word embeddings (as Wu et al. do)
    for p in discriminator_s.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator_s.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    # NOTE(review): optimizer classes are resolved via eval() on config
    # strings — only safe with trusted configuration.
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(lambda x: x.requires_grad,
                                                                 generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer_h = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad,
                                                                 discriminator_h.parameters()),
                                                          args.d_learning_rate,
                                                          momentum=args.momentum,
                                                          nesterov=True)

    d_optimizer_s = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad,
                                                                 discriminator_s.parameters()),
                                                          args.d_learning_rate,
                                                          momentum=args.momentum,
                                                          nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        # reseed per epoch so shuffling is reproducible yet varies by epoch
        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        # set training mode
        generator.train()
        discriminator_h.train()
        discriminator_s.train()
        update_learning_rate(num_update, 8e4, args.g_learning_rate, args.lr_shrink, g_optimizer)

        for i, sample in enumerate(trainloader):

            if use_cuda:
                # wrap input tensors in cuda tensors
                # NOTE(review): passes the `cuda` module (not the use_cuda
                # bool) — verify utils.make_variable accepts this.
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator
            # print("Policy Gradient Training")
            sys_out_batch_PG, p_PG, hidden_list_PG = generator('PG', epoch_i, sample)  # 64 X 50 X 6632

            out_batch_PG = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 * 50) X 6632

            _, prediction = out_batch_PG.topk(1)
            prediction = prediction.squeeze(1)  # 64*50 = 3200
            prediction = torch.reshape(prediction, sample['net_input']['src_tokens'].shape)  # 64 X 50

            # sentence-level reward from discriminator_s; computed under
            # no_grad so it acts as a constant in the PG loss
            with torch.no_grad():
                reward = discriminator_s(sample['net_input']['src_tokens'], prediction)  # 64 X 1

            train_trg_batch_PG = sample['target']  # 64 x 50

            pg_loss_PG = pg_criterion(sys_out_batch_PG, train_trg_batch_PG, reward, use_cuda)
            sample_size_PG = sample['target'].size(0) if args.sentence_avg else sample['ntokens']  # 64
            logging_loss_PG = pg_loss_PG / math.log(2)
            g_logging_meters['train_loss'].update(logging_loss_PG.item(), sample_size_PG)
            logging.debug(
                f"G policy gradient loss at batch {i}: {pg_loss_PG.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            # retain_graph: hidden_list_PG is reused below for discriminator_h
            pg_loss_PG.backward(retain_graph=True)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            g_optimizer.step()

            # print("MLE Training")
            sys_out_batch_MLE, p_MLE, hidden_list_MLE = generator("MLE", epoch_i, sample)

            out_batch_MLE = sys_out_batch_MLE.contiguous().view(-1, sys_out_batch_MLE.size(-1))  # (64 X 50) X 6632

            train_trg_batch_MLE = sample['target'].view(-1)  # 64*50 = 3200
            loss_MLE = g_criterion(out_batch_MLE, train_trg_batch_MLE)

            sample_size_MLE = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
            nsentences = sample['target'].size(0)
            logging_loss_MLE = loss_MLE.data / sample_size_MLE / math.log(2)
            g_logging_meters['bsz'].update(nsentences)
            g_logging_meters['train_loss'].update(logging_loss_MLE, sample_size_MLE)
            logging.debug(
                f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}")
            g_optimizer.zero_grad()
            # retain_graph: hidden_list_MLE is reused below for discriminator_h
            loss_MLE.backward(retain_graph=True)
            # all-reduce grads and rescale by grad_denom
            for p in generator.parameters():
                # print(p.size())
                if p.requires_grad:
                    p.grad.data.div_(sample_size_MLE)
            torch.nn.utils.clip_grad_norm_(generator.parameters(), args.clip_norm)
            g_optimizer.step()

            num_update += 1


            #  part II: train the discriminators (every 5 generator updates)

            # discriminator_h: teacher-forced vs free-running hidden states
            if num_update % 5 == 0:

                d_MLE = discriminator_h(hidden_list_MLE)
                d_PG = discriminator_h(hidden_list_PG)
                d_loss = _calcualte_discriminator_loss(d_MLE, d_PG).sum()
                logging.debug(f"D_h training loss {d_loss} at batch {i}")

                d_optimizer_h.zero_grad()
                d_loss.backward()
                torch.nn.utils.clip_grad_norm_(discriminator_h.parameters(), args.clip_norm)
                d_optimizer_h.step()




                # discriminator_s: human translation (label 1) vs generator
                # output (label 0)
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input']['src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = torch.ones(sample['target'].size(0)).float()  # 64 length vector
                with torch.no_grad():
                    sys_out_batch, p, hidden_list = generator('MLE', epoch_i, sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(-1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = torch.zeros(sample['target'].size(0)).float()  # 64 length vector

                fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence, src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                fake_disc_out = discriminator_s(src_sentence, fake_sentence)  # 64 X 1
                true_disc_out = discriminator_s(src_sentence, true_sentence)

                fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)

                # accuracy is measured on the fake half only
                acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)

                d_loss = fake_d_loss + true_d_loss

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D_s training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}")
                d_optimizer_s.zero_grad()
                d_loss.backward()
                d_optimizer_s.step()

        # validation
        # set validation mode
        generator.eval()
        discriminator_h.eval()
        discriminator_s.eval()
        # Initialize dataloader
        max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
        valloader = dataset.eval_dataloader(
            'valid',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_valid,
            skip_invalid_size_inputs_valid_test=True,
            descending=True,  # largest batch first to warm the caching allocator
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(valloader):

            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=cuda)

                # generator validation
                sys_out_batch_test, p_test, hidden_list_test = generator('test', epoch_i, sample)
                out_batch_test = sys_out_batch_test.contiguous().view(-1, sys_out_batch_test.size(-1))  # (64 X 50) X 6632
                dev_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss_test = g_criterion(out_batch_test, dev_trg_batch)
                sample_size_test = sample['target'].size(0) if args.sentence_avg else sample['ntokens']
                loss_test = loss_test / sample_size_test / math.log(2)
                g_logging_meters['valid_loss'].update(loss_test, sample_size_test)
                logging.debug(f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}")

                # # discriminator_h validation
                # bsz = sample['target'].size(0)
                # src_sentence = sample['net_input']['src_tokens']
                # # train with half human-translation and half machine translation
                # true_sentence = sample['target']
                # true_labels = torch.ones(sample['target'].size(0)).float()
                # with torch.no_grad():
                #     sys_out_batch_PG, p, hidden_list = generator('test', epoch_i, sample)
                #
                # out_batch = sys_out_batch_PG.contiguous().view(-1, sys_out_batch_PG.size(-1))  # (64 X 50) X 6632
                # _, prediction = out_batch.topk(1)
                # prediction = prediction.squeeze(1)  # 64 * 50 = 6632
                # fake_labels = torch.zeros(sample['target'].size(0)).float()
                # fake_sentence = torch.reshape(prediction, src_sentence.shape)  # 64 X 50
                # if use_cuda:
                #     fake_labels = fake_labels.cuda()
                # disc_out = discriminator_h(src_sentence, fake_sentence)
                # d_loss = d_criterion(disc_out.squeeze(1), fake_labels)
                # acc = torch.sum(torch.round(disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # d_logging_meters['valid_acc'].update(acc)
                # d_logging_meters['valid_loss'].update(d_loss)
                # logging.debug(
                #     f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}")

        # per-epoch checkpoint, named after the validation loss
        torch.save(generator,
                   open(checkpoints_path + f"sampling_{g_logging_meters['valid_loss'].avg:.3f}.epoch_{epoch_i}.pt",
                        'wb'), pickle_module=dill)

        # keep the best generator (lowest validation loss) separately
        if g_logging_meters['valid_loss'].avg < best_dev_loss:
            best_dev_loss = g_logging_meters['valid_loss'].avg
            torch.save(generator, open(checkpoints_path + "best_gmodel.pt", 'wb'), pickle_module=dill)
# Exemple #12
# 0
from discriminator import Discriminator
import torchvision.datasets as dset
from torchvision import transforms
import torch.utils.data


if __name__ == "__main__":
    # Restore the trained discriminator weights from the checkpoint.
    checkpoint = torch.load("C:\\Users\\ankit\\Workspaces\\CS7150\\FinalProject\\models\\trained_model_Mon_05_45.pth")
    discriminator = Discriminator(ngpu=1, num_channels=3, num_features=64)
    discriminator.load_state_dict(checkpoint['discriminator'])

    discriminator.eval()

    # ImageNet images normalized to [-1, 1], matching GAN training inputs.
    preprocessing = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    image_folder = dset.ImageFolder(root="C:\\Users\\ankit\\Workspaces\\CS7150\\data\\imagenet",
                                    transform=preprocessing)
    # Create the dataloader
    loader = torch.utils.data.DataLoader(image_folder, batch_size=1, shuffle=True, num_workers=2)

    # Score one random image with the discriminator.
    first_batch = next(iter(loader))
    score = discriminator(first_batch[0])

    print()


def train_d(args, dataset):
    """Pretrain the sentence-pair discriminator against a fixed generator.

    Loads the best generator checkpoint, builds (source, target, label)
    training data from human translations and generator output, and trains
    the discriminator until its validation accuracy reaches ``trg_acc``
    (0.82), the learning rate decays below ``args.min_d_lr``, or
    ``args.max_epoch`` epochs have run.  Checkpoints go to
    ``checkpoints/discriminator/``.
    """
    logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',
                        datefmt='%Y-%m-%d %H:%M:%S',
                        level=logging.DEBUG)

    use_cuda = (torch.cuda.device_count() >= 1)

    # check checkpoints saving path
    if not os.path.exists('checkpoints/discriminator'):
        os.makedirs('checkpoints/discriminator')

    checkpoints_path = 'checkpoints/discriminator/'

    logging_meters = OrderedDict()
    logging_meters['train_loss'] = AverageMeter()
    logging_meters['train_acc'] = AverageMeter()
    logging_meters['valid_loss'] = AverageMeter()
    logging_meters['valid_acc'] = AverageMeter()
    logging_meters['update_times'] = AverageMeter()

    # Build model
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)

    # Load generator (kept fixed; it only produces "fake" translations)
    assert os.path.exists('checkpoints/generator/best_gmodel.pt')
    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    model_dict = generator.state_dict()
    pretrained_dict = torch.load('checkpoints/generator/best_gmodel.pt')
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            # generator = torch.nn.DataParallel(generator).cuda()
            generator.cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    criterion = torch.nn.CrossEntropyLoss()

    # optimizer = eval("torch.optim." + args.d_optimizer)(filter(lambda x: x.requires_grad, discriminator.parameters()),
    #                                                     args.d_learning_rate, momentum=args.momentum, nesterov=True)

    optimizer = torch.optim.RMSprop(
        filter(lambda x: x.requires_grad, discriminator.parameters()), 1e-4)

    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, patience=0, factor=args.lr_shrink)

    # Train until the target accuracy (or another stop condition) is reached
    max_epoch = args.max_epoch or math.inf
    epoch_i = 1
    trg_acc = 0.82
    best_dev_loss = math.inf
    lr = optimizer.param_groups[0]['lr']

    # training/validation data (the validation set is prepared only once)
    train = prepare_training_data(args, dataset, 'train', generator, epoch_i,
                                  use_cuda)
    valid = prepare_training_data(args, dataset, 'valid', generator, epoch_i,
                                  use_cuda)
    data_train = DatasetProcessing(data=train, maxlen=args.fixed_max_len)
    data_valid = DatasetProcessing(data=valid, maxlen=args.fixed_max_len)

    # main training loop
    while lr > args.min_d_lr and epoch_i <= max_epoch:
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        # regenerate training data each epoch when sampling without replacement
        if args.sample_without_replacement > 0 and epoch_i > 1:
            train = prepare_training_data(args, dataset, 'train', generator,
                                          epoch_i, use_cuda)
            data_train = DatasetProcessing(data=train,
                                           maxlen=args.fixed_max_len)

        # discriminator training dataloader
        train_loader = train_dataloader(data_train,
                                        batch_size=args.joint_batch_size,
                                        seed=seed,
                                        epoch=epoch_i,
                                        sort_by_source_size=False)

        valid_loader = eval_dataloader(data_valid,
                                       num_workers=4,
                                       batch_size=args.joint_batch_size)

        # set training mode
        discriminator.train()

        # reset meters
        for key, val in logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(train_loader):
            if use_cuda:
                # wrap input tensors in cuda tensors
                sample = utils.make_variable(sample, cuda=use_cuda)

            disc_out = discriminator(sample['src_tokens'],
                                     sample['trg_tokens'])

            loss = criterion(disc_out, sample['labels'])
            _, prediction = F.softmax(disc_out, dim=1).topk(1)
            acc = torch.sum(
                prediction == sample['labels'].unsqueeze(1)).float() / len(
                    sample['labels'])

            logging_meters['train_acc'].update(acc.item())
            logging_meters['train_loss'].update(loss.item())
            logging.debug("D training loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                          format(logging_meters['train_loss'].avg, acc, logging_meters['train_acc'].avg,
                                 optimizer.param_groups[0]['lr'], i))

            optimizer.zero_grad()
            loss.backward()
            # use the in-place clip_grad_norm_ (clip_grad_norm is deprecated
            # and inconsistent with the rest of this file)
            torch.nn.utils.clip_grad_norm_(discriminator.parameters(),
                                           args.clip_norm)
            optimizer.step()

            # del src_tokens, trg_tokens, loss, disc_out, labels, prediction, acc
            del disc_out, loss, prediction, acc

        # set validation mode
        discriminator.eval()

        for i, sample in enumerate(valid_loader):
            with torch.no_grad():
                if use_cuda:
                    # wrap input tensors in cuda tensors
                    sample = utils.make_variable(sample, cuda=use_cuda)

                disc_out = discriminator(sample['src_tokens'],
                                         sample['trg_tokens'])

                loss = criterion(disc_out, sample['labels'])
                _, prediction = F.softmax(disc_out, dim=1).topk(1)
                acc = torch.sum(
                    prediction == sample['labels'].unsqueeze(1)).float() / len(
                        sample['labels'])

                logging_meters['valid_acc'].update(acc.item())
                logging_meters['valid_loss'].update(loss.item())
                logging.debug("D eval loss {0:.3f}, acc {1:.3f}, avgAcc {2:.3f}, lr={3} at batch {4}: ". \
                              format(logging_meters['valid_loss'].avg, acc, logging_meters['valid_acc'].avg,
                                     optimizer.param_groups[0]['lr'], i))

            del disc_out, loss, prediction, acc

        lr_scheduler.step(logging_meters['valid_loss'].avg)
        # refresh the local lr so the while-condition actually observes
        # scheduler decay (previously it compared a stale value forever)
        lr = optimizer.param_groups[0]['lr']

        if logging_meters['valid_acc'].avg >= 0.70:
            torch.save(discriminator.state_dict(), checkpoints_path + "ce_{0:.3f}_acc_{1:.3f}.epoch_{2}.pt" \
                       .format(logging_meters['valid_loss'].avg, logging_meters['valid_acc'].avg, epoch_i))

            if logging_meters['valid_loss'].avg < best_dev_loss:
                best_dev_loss = logging_meters['valid_loss'].avg
                torch.save(discriminator.state_dict(),
                           checkpoints_path + "best_dmodel.pt")

        # pretrain the discriminator to achieve accuracy 82%
        if logging_meters['valid_acc'].avg >= trg_acc:
            return

        epoch_i += 1
class SegPoseNet(nn.Module):
    """Darknet backbone with segmentation, 2D-pose and domain-discriminator heads."""

    def __init__(self, data_options):
        """Build the network from a data-options dict (sizes, domains, arch cfg)."""
        super(SegPoseNet, self).__init__()

        pose_arch_cfg = data_options['pose_arch_cfg']
        self.width = int(data_options['width'])
        self.height = int(data_options['height'])
        self.channels = int(data_options['channels'])
        self.domains = int(data_options['domains'])

        # note you need to change this after modifying the network
        self.output_h = 76
        self.output_w = 76

        self.coreModel = Darknet(pose_arch_cfg, self.width, self.height, self.channels, self.domains)
        self.segLayer = PoseSegLayer(data_options)
        self.regLayer = Pose2DLayer(data_options)
        self.discLayer = Discriminator()
        self.training = False

    def forward(self, x, y=None, adapt=False, domains=None):
        """Run the backbone and all heads; returns [seg, pose, disc, out4, out5].

        During adaptive training (self.training and adapt) the seg/pose heads
        only see source-domain samples (via source_only); the discriminator
        head always sees every sample.
        """
        outlayers = self.coreModel(x, domains=domains)

        if self.training and adapt:
            in1 = source_only(outlayers[0], domains)
            in2 = source_only(outlayers[1], domains)
        else:
            in1 = outlayers[0]
            in2 = outlayers[1]

        out3 = self.discLayer(outlayers[2])
        out4 = outlayers[3]
        out5 = outlayers[4]

        out1 = self.segLayer(in1)
        out2 = self.regLayer(in2)

        out_preds = [out1, out2, out3, out4, out5]
        return out_preds

    def train(self, mode=True):
        """Switch all submodules to train mode.

        Accepts nn.Module's ``mode`` flag so ``model.train(False)`` behaves
        like ``eval()`` (the previous override broke that contract).
        """
        if not mode:
            return self.eval()
        self.coreModel.train()
        self.segLayer.train()
        self.regLayer.train()
        self.discLayer.train()
        self.training = True
        return self

    def eval(self):
        """Switch all submodules to eval mode."""
        self.coreModel.eval()
        self.segLayer.eval()
        self.regLayer.eval()
        self.discLayer.eval()
        self.training = False
        return self

    def print_network(self):
        """Delegate network pretty-printing to the backbone."""
        self.coreModel.print_network()

    def load_weights(self, weightfile):
        """Load backbone weights only (heads are not included in the file)."""
        self.coreModel.load_state_dict(torch.load(weightfile))

    def save_weights(self, weightfile):
        """Save backbone weights only (heads are not included in the file)."""
        torch.save(self.coreModel.state_dict(), weightfile)
# Exemple #15
# 0
        # Loss and accuracy within the current epoch.
        loss1.append(gradient_penalty.item())
        loss2.append(disc_fake_source.item())
        loss3.append(disc_real_source.item())
        loss4.append(disc_real_class.item())
        loss5.append(disc_fake_class.item())
        acc1.append(accuracy)

        if batch_idx % 50 == 0:
            print("[", epoch, batch_idx, "]", "%.2f" % np.mean(loss1),
                  "%.2f" % np.mean(loss2), "%.2f" % np.mean(loss3),
                  "%.2f" % np.mean(loss4), "%.2f" % np.mean(loss5),
                  "%.2f" % np.mean(acc1))

    # Test the model after every epoch.
    aD.eval()
    with torch.no_grad():
        test_accu = []
        for batch_idx, (X_test_batch, Y_test_batch) in enumerate(testloader):
            X_test_batch, Y_test_batch = Variable(
                X_test_batch).cuda(), Variable(Y_test_batch).cuda()

            with torch.no_grad():
                _, output = aD(X_test_batch)

            prediction = output.data.max(1)[1]  # first column has actual prob.
            accuracy = (float(prediction.eq(Y_test_batch.data).sum()) /
                        float(batch_size)) * 100.0
            test_accu.append(accuracy)
            accuracy_test = np.mean(test_accu)
# Exemple #16
# 0
def main(args):
    """Jointly train a generator/discriminator NMT pair (adversarial NMT).

    Loads a parallel dataset, warm-starts both models from checkpoints,
    then alternates per batch between policy-gradient and MLE updates for
    the generator, trains the discriminator every 5th update, and runs a
    validation pass (plus a generator checkpoint save) every 10000 updates.
    """
    use_cuda = (len(args.gpuid) >= 1)
    print("{0} GPU(s) are available".format(cuda.device_count()))

    # Load dataset
    # Prefer the binarized dataset when present; fall back to raw text.
    splits = ['train', 'valid']
    if data.has_binary_files(args.data, splits):
        dataset = data.load_dataset(args.data, splits, args.src_lang,
                                    args.trg_lang, args.fixed_max_len)
    else:
        dataset = data.load_raw_text_dataset(args.data, splits, args.src_lang,
                                             args.trg_lang, args.fixed_max_len)
    if args.src_lang is None or args.trg_lang is None:
        # record inferred languages in args, so that it's saved in checkpoints
        args.src_lang, args.trg_lang = dataset.src, dataset.dst

    print('| [{}] dictionary: {} types'.format(dataset.src,
                                               len(dataset.src_dict)))
    print('| [{}] dictionary: {} types'.format(dataset.dst,
                                               len(dataset.dst_dict)))

    for split in splits:
        print('| {} {} {} examples'.format(args.data, split,
                                           len(dataset.splits[split])))

    # Running-average meters for generator (g_) and discriminator (d_) stats.
    g_logging_meters = OrderedDict()
    g_logging_meters['train_loss'] = AverageMeter()
    g_logging_meters['valid_loss'] = AverageMeter()
    g_logging_meters['train_acc'] = AverageMeter()
    g_logging_meters['valid_acc'] = AverageMeter()
    g_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    d_logging_meters = OrderedDict()
    d_logging_meters['train_loss'] = AverageMeter()
    d_logging_meters['valid_loss'] = AverageMeter()
    d_logging_meters['train_acc'] = AverageMeter()
    d_logging_meters['valid_acc'] = AverageMeter()
    d_logging_meters['bsz'] = AverageMeter()  # sentences per batch

    # Set model parameters
    # NOTE(review): these overwrite any CLI-provided values for the same args.
    args.encoder_embed_dim = 1000
    args.encoder_layers = 2  # 4
    args.encoder_dropout_out = 0
    args.decoder_embed_dim = 1000
    args.decoder_layers = 2  # 4
    args.decoder_out_embed_dim = 1000
    args.decoder_dropout_out = 0
    args.bidirectional = False

    generator = LSTMModel(args,
                          dataset.src_dict,
                          dataset.dst_dict,
                          use_cuda=use_cuda)
    print("Generator loaded successfully!")
    discriminator = Discriminator(args,
                                  dataset.src_dict,
                                  dataset.dst_dict,
                                  use_cuda=use_cuda)
    print("Discriminator loaded successfully!")

    # Warm-start the generator: load only the checkpoint keys that match the
    # freshly built model's state dict (partial/filtered load).
    g_model_path = 'checkpoints/zhenwarm/generator.pt'
    assert os.path.exists(g_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    model_dict = generator.state_dict()
    model = torch.load(g_model_path)
    pretrained_dict = model.state_dict()
    # 1. filter out unnecessary keys
    pretrained_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    generator.load_state_dict(model_dict)
    print("pre-trained Generator loaded successfully!")
    #
    # Load discriminator model
    # Same filtered warm-start procedure for the discriminator.
    d_model_path = 'checkpoints/zhenwarm/discri.pt'
    assert os.path.exists(d_model_path)
    # generator = LSTMModel(args, dataset.src_dict, dataset.dst_dict, use_cuda=use_cuda)
    d_model_dict = discriminator.state_dict()
    d_model = torch.load(d_model_path)
    d_pretrained_dict = d_model.state_dict()
    # 1. filter out unnecessary keys
    d_pretrained_dict = {
        k: v
        for k, v in d_pretrained_dict.items() if k in d_model_dict
    }
    # 2. overwrite entries in the existing state dict
    d_model_dict.update(d_pretrained_dict)
    # 3. load the new state dict
    discriminator.load_state_dict(d_model_dict)
    print("pre-trained Discriminator loaded successfully!")

    if use_cuda:
        if torch.cuda.device_count() > 1:
            discriminator = torch.nn.DataParallel(discriminator).cuda()
            generator = torch.nn.DataParallel(generator).cuda()
        else:
            generator.cuda()
            discriminator.cuda()
    else:
        discriminator.cpu()
        generator.cpu()

    # adversarial training checkpoints saving path
    if not os.path.exists('checkpoints/myzhencli5'):
        os.makedirs('checkpoints/myzhencli5')
    checkpoints_path = 'checkpoints/myzhencli5/'

    # define loss function
    g_criterion = torch.nn.NLLLoss(ignore_index=dataset.dst_dict.pad(),
                                   reduction='sum')
    d_criterion = torch.nn.BCELoss()
    pg_criterion = PGLoss(ignore_index=dataset.dst_dict.pad(),
                          size_average=True,
                          reduce=True)

    # fix discriminator word embedding (as Wu et al. do)
    for p in discriminator.embed_src_tokens.parameters():
        p.requires_grad = False
    for p in discriminator.embed_trg_tokens.parameters():
        p.requires_grad = False

    # define optimizer
    # NOTE(review): eval() on the optimizer name executes arbitrary code from
    # args — getattr(torch.optim, args.g_optimizer) would be the safe form.
    g_optimizer = eval("torch.optim." + args.g_optimizer)(filter(
        lambda x: x.requires_grad, generator.parameters()),
                                                          args.g_learning_rate)

    d_optimizer = eval("torch.optim." + args.d_optimizer)(
        filter(lambda x: x.requires_grad, discriminator.parameters()),
        args.d_learning_rate,
        momentum=args.momentum,
        nesterov=True)

    # start joint training
    best_dev_loss = math.inf
    num_update = 0
    # main training loop
    for epoch_i in range(1, args.epochs + 1):
        logging.info("At {0}-th epoch.".format(epoch_i))

        seed = args.seed + epoch_i
        torch.manual_seed(seed)

        max_positions_train = (args.fixed_max_len, args.fixed_max_len)

        # Initialize dataloader, starting at batch_offset
        trainloader = dataset.train_dataloader(
            'train',
            max_tokens=args.max_tokens,
            max_sentences=args.joint_batch_size,
            max_positions=max_positions_train,
            # seed=seed,
            epoch=epoch_i,
            sample_without_replacement=args.sample_without_replacement,
            sort_by_source_size=(epoch_i <= args.curriculum),
            shard_id=args.distributed_rank,
            num_shards=args.distributed_world_size,
        )

        # reset meters
        for key, val in g_logging_meters.items():
            if val is not None:
                val.reset()
        for key, val in d_logging_meters.items():
            if val is not None:
                val.reset()

        for i, sample in enumerate(trainloader):

            # set training mode
            generator.train()
            discriminator.train()
            update_learning_rate(num_update, 8e4, args.g_learning_rate,
                                 args.lr_shrink, g_optimizer)

            if use_cuda:
                # wrap input tensors in cuda tensors
                # NOTE(review): `cuda=cuda` passes the module object (always
                # truthy) rather than the `use_cuda` bool — confirm against
                # utils.make_variable's contract.
                sample = utils.make_variable(sample, cuda=cuda)

            ## part I: use gradient policy method to train the generator

            # use policy gradient training when random.random() > 50%
            if random.random() >= 0.5:

                print("Policy Gradient Training")

                sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 * 50) X 6632

                # Greedy (top-1) tokens become the sampled translation.
                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64*50 = 3200
                prediction = torch.reshape(
                    prediction,
                    sample['net_input']['src_tokens'].shape)  # 64 X 50

                # Discriminator score of the sampled translation is the reward.
                with torch.no_grad():
                    reward = discriminator(sample['net_input']['src_tokens'],
                                           prediction)  # 64 X 1

                train_trg_batch = sample['target']  # 64 x 50

                pg_loss = pg_criterion(sys_out_batch, train_trg_batch, reward,
                                       use_cuda)
                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']  # 64
                logging_loss = pg_loss / math.log(2)
                g_logging_meters['train_loss'].update(logging_loss.item(),
                                                      sample_size)
                logging.debug(
                    f"G policy gradient loss at batch {i}: {pg_loss.item():.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                pg_loss.backward()
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            else:
                # MLE training
                print("MLE Training")

                sys_out_batch = generator(sample)

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                train_trg_batch = sample['target'].view(-1)  # 64*50 = 3200

                loss = g_criterion(out_batch, train_trg_batch)

                sample_size = sample['target'].size(
                    0) if args.sentence_avg else sample['ntokens']
                nsentences = sample['target'].size(0)
                logging_loss = loss.data / sample_size / math.log(2)
                g_logging_meters['bsz'].update(nsentences)
                g_logging_meters['train_loss'].update(logging_loss,
                                                      sample_size)
                logging.debug(
                    f"G MLE loss at batch {i}: {g_logging_meters['train_loss'].avg:.3f}, lr={g_optimizer.param_groups[0]['lr']}"
                )
                g_optimizer.zero_grad()
                loss.backward()
                # all-reduce grads and rescale by grad_denom
                for p in generator.parameters():
                    if p.requires_grad:
                        p.grad.data.div_(sample_size)
                torch.nn.utils.clip_grad_norm_(generator.parameters(),
                                               args.clip_norm)
                g_optimizer.step()

            num_update += 1

            # part II: train the discriminator
            # Only every 5th update, alternating randomly between a fake-only
            # and a real-only discriminator step.
            if num_update % 5 == 0:
                bsz = sample['target'].size(0)  # batch_size = 64

                src_sentence = sample['net_input'][
                    'src_tokens']  # 64 x max-len i.e 64 X 50

                # now train with machine translation output i.e generator output
                true_sentence = sample['target'].view(-1)  # 64*50 = 3200

                true_labels = Variable(
                    torch.ones(
                        sample['target'].size(0)).float())  # 64 length vector

                with torch.no_grad():
                    sys_out_batch = generator(sample)  # 64 X 50 X 6632

                out_batch = sys_out_batch.contiguous().view(
                    -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                _, prediction = out_batch.topk(1)
                prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                fake_labels = Variable(
                    torch.zeros(
                        sample['target'].size(0)).float())  # 64 length vector

                fake_sentence = torch.reshape(prediction,
                                              src_sentence.shape)  # 64 X 50
                true_sentence = torch.reshape(true_sentence,
                                              src_sentence.shape)
                if use_cuda:
                    fake_labels = fake_labels.cuda()
                    true_labels = true_labels.cuda()

                # fake_disc_out = discriminator(src_sentence, fake_sentence)  # 64 X 1
                # true_disc_out = discriminator(src_sentence, true_sentence)
                #
                # fake_d_loss = d_criterion(fake_disc_out.squeeze(1), fake_labels)
                # true_d_loss = d_criterion(true_disc_out.squeeze(1), true_labels)
                #
                # fake_acc = torch.sum(torch.round(fake_disc_out).squeeze(1) == fake_labels).float() / len(fake_labels)
                # true_acc = torch.sum(torch.round(true_disc_out).squeeze(1) == true_labels).float() / len(true_labels)
                # acc = (fake_acc + true_acc) / 2
                #
                # d_loss = fake_d_loss + true_d_loss
                if random.random() > 0.5:
                    fake_disc_out = discriminator(src_sentence, fake_sentence)
                    fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                              fake_labels)
                    fake_acc = torch.sum(
                        torch.round(fake_disc_out).squeeze(1) ==
                        fake_labels).float() / len(fake_labels)
                    d_loss = fake_d_loss
                    acc = fake_acc
                else:
                    true_disc_out = discriminator(src_sentence, true_sentence)
                    true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                              true_labels)
                    true_acc = torch.sum(
                        torch.round(true_disc_out).squeeze(1) ==
                        true_labels).float() / len(true_labels)
                    d_loss = true_d_loss
                    acc = true_acc

                d_logging_meters['train_acc'].update(acc)
                d_logging_meters['train_loss'].update(d_loss)
                logging.debug(
                    f"D training loss {d_logging_meters['train_loss'].avg:.3f}, acc {d_logging_meters['train_acc'].avg:.3f} at batch {i}"
                )
                d_optimizer.zero_grad()
                d_loss.backward()
                d_optimizer.step()

            # Periodic validation + generator checkpointing.
            if num_update % 10000 == 0:

                # validation
                # set validation mode
                generator.eval()
                discriminator.eval()
                # Initialize dataloader
                max_positions_valid = (args.fixed_max_len, args.fixed_max_len)
                valloader = dataset.eval_dataloader(
                    'valid',
                    max_tokens=args.max_tokens,
                    max_sentences=args.joint_batch_size,
                    max_positions=max_positions_valid,
                    skip_invalid_size_inputs_valid_test=True,
                    descending=
                    True,  # largest batch first to warm the caching allocator
                    shard_id=args.distributed_rank,
                    num_shards=args.distributed_world_size,
                )

                # reset meters
                for key, val in g_logging_meters.items():
                    if val is not None:
                        val.reset()
                for key, val in d_logging_meters.items():
                    if val is not None:
                        val.reset()

                # NOTE(review): this `i` shadows the training-loop batch index
                # used in the debug messages above.
                for i, sample in enumerate(valloader):

                    with torch.no_grad():
                        if use_cuda:
                            # wrap input tensors in cuda tensors
                            sample = utils.make_variable(sample, cuda=cuda)

                        # generator validation
                        sys_out_batch = generator(sample)
                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632
                        dev_trg_batch = sample['target'].view(
                            -1)  # 64*50 = 3200

                        loss = g_criterion(out_batch, dev_trg_batch)
                        sample_size = sample['target'].size(
                            0) if args.sentence_avg else sample['ntokens']
                        loss = loss / sample_size / math.log(2)
                        g_logging_meters['valid_loss'].update(
                            loss, sample_size)
                        logging.debug(
                            f"G dev loss at batch {i}: {g_logging_meters['valid_loss'].avg:.3f}"
                        )

                        # discriminator validation
                        bsz = sample['target'].size(0)
                        src_sentence = sample['net_input']['src_tokens']
                        # train with half human-translation and half machine translation

                        true_sentence = sample['target']
                        true_labels = Variable(
                            torch.ones(sample['target'].size(0)).float())

                        with torch.no_grad():
                            sys_out_batch = generator(sample)

                        out_batch = sys_out_batch.contiguous().view(
                            -1, sys_out_batch.size(-1))  # (64 X 50) X 6632

                        _, prediction = out_batch.topk(1)
                        prediction = prediction.squeeze(1)  # 64 * 50 = 6632

                        fake_labels = Variable(
                            torch.zeros(sample['target'].size(0)).float())

                        fake_sentence = torch.reshape(
                            prediction, src_sentence.shape)  # 64 X 50
                        true_sentence = torch.reshape(true_sentence,
                                                      src_sentence.shape)
                        if use_cuda:
                            fake_labels = fake_labels.cuda()
                            true_labels = true_labels.cuda()

                        fake_disc_out = discriminator(src_sentence,
                                                      fake_sentence)  # 64 X 1
                        true_disc_out = discriminator(src_sentence,
                                                      true_sentence)

                        fake_d_loss = d_criterion(fake_disc_out.squeeze(1),
                                                  fake_labels)
                        true_d_loss = d_criterion(true_disc_out.squeeze(1),
                                                  true_labels)
                        d_loss = fake_d_loss + true_d_loss
                        fake_acc = torch.sum(
                            torch.round(fake_disc_out).squeeze(1) ==
                            fake_labels).float() / len(fake_labels)
                        true_acc = torch.sum(
                            torch.round(true_disc_out).squeeze(1) ==
                            true_labels).float() / len(true_labels)
                        acc = (fake_acc + true_acc) / 2
                        d_logging_meters['valid_acc'].update(acc)
                        d_logging_meters['valid_loss'].update(d_loss)
                        logging.debug(
                            f"D dev loss {d_logging_meters['valid_loss'].avg:.3f}, acc {d_logging_meters['valid_acc'].avg:.3f} at batch {i}"
                        )

                # torch.save(discriminator,
                #            open(checkpoints_path + f"numupdate_{num_update/10000}k.discri_{d_logging_meters['valid_loss'].avg:.3f}.pt",'wb'), pickle_module=dill)

                # if d_logging_meters['valid_loss'].avg < best_dev_loss:
                #     best_dev_loss = d_logging_meters['valid_loss'].avg
                #     torch.save(discriminator, open(checkpoints_path + "best_dmodel.pt", 'wb'), pickle_module=dill)

                torch.save(
                    generator,
                    open(
                        checkpoints_path +
                        f"numupdate_{num_update/10000}k.joint_{g_logging_meters['valid_loss'].avg:.3f}.pt",
                        'wb'),
                    pickle_module=dill)
Exemple #17
0
targeted_model = torch.nn.DataParallel(resnet_model.__dict__['resnet32']())
targeted_model.cuda()
targeted_model.load_state_dict(checkpoint['state_dict'])
targeted_model.eval()

# load the generator of adversarial gan
pretrained_generator_path = './models/lp_pretrained/netG_rl_epoch_20.pth'
pretrained_G = Generator().to(device)
pretrained_G.load_state_dict(torch.load(pretrained_generator_path))
pretrained_G.eval()

# load the discriminator of adversarial gan
pretrained_disciminator_path = './models/lp_pretrained/netDisc_rl_epoch_20.pth'
pretrained_Disc = Discriminator().to(device)
pretrained_Disc.load_state_dict(torch.load(pretrained_disciminator_path))
pretrained_Disc.eval()

# load the Pixel Valuation network
pretrained_pvrl_path = './models/lp_pretrained/netPv_rl_epoch_20.pth'
pretrained_PV = PVRL().to(device)
pretrained_PV.load_state_dict(torch.load(pretrained_pvrl_path))
pretrained_PV.eval()

# test adversarial examples in CIFAR10 training dataset
cifar_dataset = torchvision.datasets.CIFAR10('./data',
                                             train=True,
                                             transform=transforms.ToTensor(),
                                             download=True)
train_dataloader = DataLoader(cifar_dataset,
                              batch_size=batch_size,
                              shuffle=False,
Exemple #18
0
def semi_main(options):
    """Run semi-supervised Bayesian-GAN training on CIFAR10.

    Builds a chain of generators (sampled via SGHMC) and a single
    (10+1)-class discriminator, then trains on real / fake / labelled
    batches, periodically evaluating test accuracy and saving sample
    images. Loss statistics are dumped to disk even if training aborts.

    Args:
        options: argparse.Namespace with CLI options (dataset, nz, ngf,
            ndf, nc, mcmc, nd, nsetz, lr, beta1, alpha, niter, notefreq,
            batchSize, imageSize, outf, loglevel, ngpu, ...).

    Raises:
        NotImplementedError: if options.dataset is not 'cifar10'.
    """
    import math  # hoisted from mid-function; used for the SGHMC noise scales

    print('\nSemi-Supervised Learning!\n')

    # 1. Make sure the options are valid argparse CLI options indeed
    assert isinstance(options, argparse.Namespace)

    # 2. Set up the logger
    logging.basicConfig(level=str(options.loglevel).upper())

    # 3. Make sure the output dir `outf` exists
    _check_out_dir(options)

    # 4. Set the random state
    _set_random_state(options)

    # 5. Configure CUDA and Cudnn, set the global `device` for PyTorch
    device = _configure_cuda(options)

    # 6. Prepare the datasets and split it for semi-supervised learning
    if options.dataset != 'cifar10':
        raise NotImplementedError(
            'Semi-supervised learning only support CIFAR10 dataset at the moment!'
        )
    test_data_loader, semi_data_loader, train_data_loader = _prepare_semi_dataset(
        options)

    # 7. Set the parameters
    ngpu = int(options.ngpu)  # num of GPUs
    nz = int(
        options.nz)  # size of latent vector, also the number of the generators
    ngf = int(options.ngf)  # depth of feature maps through G
    ndf = int(options.ndf)  # depth of feature maps through D
    nc = int(options.nc
             )  # num of channels of the input images, 3 indicates color images
    M = int(options.mcmc)  # num of SGHMC chains run concurrently
    nd = int(options.nd)  # num of discriminators
    nsetz = int(options.nsetz)  # num of noise batches

    # 8. Special preparations for Bayesian GAN for Generators

    # In order to inject the SGHMAC into the training process, instead of pause the gradient descent at
    # each training step, which can be easily defined with static computation graph(Tensorflow), in PyTorch,
    # we have to move the Generator Sampling to the very beginning of the whole training process, and use
    # a trick that initializing all of the generators explicitly for later usages.
    Generator_chains = []
    for _ in range(nsetz):
        for __ in range(M):
            netG = Generator(ngpu, nz, ngf, nc).to(device)
            netG.apply(weights_init)
            Generator_chains.append(netG)

    logging.info(
        f'Showing the first generator of the Generator chain: \n {Generator_chains[0]}\n'
    )

    # 9. Special preparations for Bayesian GAN for Discriminators
    assert options.dataset == 'cifar10', 'Semi-supervised learning only support CIFAR10 dataset at the moment!'

    # 10 real classes + 1 "fake" class at index 0.
    num_class = 10 + 1

    # To simplify the implementation we only consider the situation of 1 discriminator
    netD = Discriminator(ngpu, ndf, nc, num_class=num_class).to(device)
    netD.apply(weights_init)
    logging.info(f'Showing the Discriminator model: \n {netD}\n')

    # 10. Loss function
    criterion = nn.CrossEntropyLoss()
    all_criterion = ComplementCrossEntropyLoss(except_index=0, device=device)

    # 11. Set up optimizers — one per generator in the chain, one for D.
    optimizerG_chains = [
        optim.Adam(netG.parameters(),
                   lr=options.lr,
                   betas=(options.beta1, 0.999)) for netG in Generator_chains
    ]
    optimizerD = optim.Adam(netD.parameters(),
                            lr=options.lr,
                            betas=(options.beta1, 0.999))

    # 12. Set up the losses for priors and noises (SGHMC terms)
    gprior = PriorLoss(prior_std=1., total=500.)
    gnoise = NoiseLoss(params=Generator_chains[0].parameters(),
                       device=device,
                       scale=math.sqrt(2 * options.alpha / options.lr),
                       total=500.)
    dprior = PriorLoss(prior_std=1., total=50000.)
    dnoise = NoiseLoss(params=netD.parameters(),
                       device=device,
                       scale=math.sqrt(2 * options.alpha * options.lr),
                       total=50000.)

    gprior.to(device=device)
    gnoise.to(device=device)
    dprior.to(device=device)
    dnoise.to(device=device)

    # In order to let G condition on a specific noise, we attach the noise to a fixed Tensor
    fixed_noise = torch.FloatTensor(options.batchSize, options.nz, 1,
                                    1).normal_(0, 1).to(device=device)
    inputT = torch.FloatTensor(options.batchSize, 3, options.imageSize,
                               options.imageSize).to(device=device)
    noiseT = torch.FloatTensor(options.batchSize, options.nz, 1,
                               1).to(device=device)
    labelT = torch.FloatTensor(options.batchSize).to(device=device)
    real_label = 1
    fake_label = 0

    # 13. Transfer all the tensors and modules to GPU if applicable
    netD.to(device=device)
    for netG in Generator_chains:
        netG.to(device=device)
    criterion.to(device=device)
    all_criterion.to(device=device)

    # ========================
    # === Training Process ===
    # ========================

    # Lists to keep track of progress
    img_list = []
    G_losses = []
    D_losses = []
    stats = []
    iters = 0

    try:
        print("\nStarting Training Loop...\n")
        for epoch in range(options.niter):
            top1 = Metrics()
            for i, data in enumerate(train_data_loader, 0):
                # ##################
                # Train with real
                # ##################
                netD.zero_grad()
                real_cpu = data[0].to(device)
                batch_size = real_cpu.size(0)

                inputT.resize_as_(real_cpu).copy_(real_cpu)
                labelT.resize_(batch_size).fill_(real_label)

                inputv = torch.autograd.Variable(inputT)
                labelv = torch.autograd.Variable(labelT)

                output = netD(inputv)
                # P(real) = 1 - P(class 0 = fake)
                errD_real = all_criterion(output)
                errD_real.backward()
                D_x = 1 - torch.nn.functional.softmax(
                    output).data[:, 0].mean().item()

                # ##################
                # Train with fake
                # ##################
                fake_images = []
                for i_z in range(nsetz):
                    noiseT.resize_(batch_size, nz, 1, 1).normal_(
                        0, 1)  # prior, sample from N(0, 1) distribution
                    noisev = torch.autograd.Variable(noiseT)
                    for m in range(M):
                        idx = i_z * M + m
                        netG = Generator_chains[idx]
                        _fake = netG(noisev)
                        fake_images.append(_fake)
                fake = torch.cat(fake_images)
                output = netD(fake.detach())

                labelv = torch.autograd.Variable(
                    torch.LongTensor(fake.data.shape[0]).to(
                        device=device).fill_(fake_label))
                errD_fake = criterion(output, labelv)
                errD_fake.backward()
                D_G_z1 = 1 - torch.nn.functional.softmax(
                    output).data[:, 0].mean().item()

                # ##################
                # Semi-supervised learning
                # ##################
                # Grab exactly one labelled batch per iteration.
                for ii, (input_sup, target_sup) in enumerate(semi_data_loader):
                    input_sup, target_sup = input_sup.to(
                        device=device), target_sup.to(device=device)
                    break
                input_sup_v = input_sup.to(device=device)
                # Shift labels by 1: index 0 is reserved for the fake class.
                target_sup_v = (target_sup + 1).to(device=device)
                output_sup = netD(input_sup_v)
                err_sup = criterion(output_sup, target_sup_v)
                err_sup.backward()
                pred1 = accuracy(output_sup.data, target_sup + 1,
                                 topk=(1, ))[0]
                top1.update(value=pred1.item(), N=input_sup.size(0))

                errD_prior = dprior(netD.parameters())
                errD_prior.backward()
                errD_noise = dnoise(netD.parameters())
                errD_noise.backward()
                errD = errD_real + errD_fake + err_sup + errD_prior + errD_noise
                optimizerD.step()

                # ##################
                # Sample and construct generator(s)
                # ##################
                for netG in Generator_chains:
                    netG.zero_grad()
                output = netD(fake)
                errG = all_criterion(output)

                for netG in Generator_chains:
                    errG = errG + gprior(netG.parameters())
                    errG = errG + gnoise(netG.parameters())
                errG.backward()
                D_G_z2 = 1 - torch.nn.functional.softmax(
                    output).data[:, 0].mean().item()
                for optimizerG in optimizerG_chains:
                    optimizerG.step()

                # ##################
                # Evaluate testing accuracy
                # ##################
                # Pause and compute the test accuracy after every 10 times of the notefreq
                # BUG FIX: `iters % 10 * int(options.notefreq)` parsed as
                # `(iters % 10) * notefreq`, which fired every 10 iterations
                # regardless of notefreq; the product must be parenthesized.
                if iters % (10 * int(options.notefreq)) == 0:
                    # get test accuracy on train and test
                    netD.eval()
                    compute_test_accuracy(discriminator=netD,
                                          testing_data_loader=test_data_loader,
                                          device=device)
                    netD.train()

                # ##################
                # Note down
                # ##################
                # Report status for the current iteration
                training_status = f"[{epoch}/{options.niter}][{i}/{len(train_data_loader)}] Loss_D: {errD.item():.4f} " \
                                  f"Loss_G: " \
                                  f"{errG.item():.4f} D(x): {D_x:.4f} D(G(z)): {D_G_z1:.4f}/{D_G_z2:.4f}" \
                                  f" | Acc {top1.value:.1f} / {top1.mean:.1f}"
                print(training_status)

                # Save samples to disk
                if i % int(options.notefreq) == 0:
                    vutils.save_image(
                        real_cpu,
                        f"{options.outf}/real_samples_epoch_{epoch:{0}{3}}_{i}.png",
                        normalize=True)
                    for _iz in range(nsetz):
                        for _m in range(M):
                            gidx = _iz * M + _m
                            netG = Generator_chains[gidx]
                            fake = netG(fixed_noise)
                            vutils.save_image(
                                fake.detach(),
                                f"{options.outf}/fake_samples_epoch_{epoch:{0}{3}}_{i}_z{_iz}_m{_m}.png",
                                normalize=True)

                    # Save Losses statistics for post-mortem
                    G_losses.append(errG.item())
                    D_losses.append(errD.item())
                    stats.append(training_status)

                    # NOTE(review): `iters` only advances inside this
                    # notefreq branch, so the test-accuracy schedule above
                    # depends on it — confirm this indentation is intended.
                    iters += 1
            # TODO: find an elegant way to support saving checkpoints in Bayesian GAN context
    except Exception as e:
        # Best-effort: report the error but still persist the statistics below.
        print(e)

        # save training stats no matter what kind of errors occur in the processes
        _save_stats(statistic=G_losses, save_name='G_losses', options=options)
        _save_stats(statistic=D_losses, save_name='D_losses', options=options)
        _save_stats(statistic=stats,
                    save_name='Training_stats',
                    options=options)
Exemple #19
0
class AttnGAN:
    def __init__(self, damsm, device=DEVICE):
        """Build the generator/discriminator pair around a pretrained, frozen DAMSM."""
        self.device = device
        self.gen = Generator(device)
        self.disc = Discriminator(device)

        # The DAMSM encoders are used only as fixed feature extractors:
        # eval mode and frozen parameters.
        self.damsm = damsm.to(device)
        self.damsm.txt_enc.eval()
        self.damsm.img_enc.eval()
        freeze_params_(self.damsm.txt_enc)
        freeze_params_(self.damsm.img_enc)

        self.gen.apply(init_weights)
        self.disc.apply(init_weights)

        adam_betas = (0.5, 0.999)
        self.gen_optimizer = torch.optim.Adam(self.gen.parameters(),
                                              lr=GENERATOR_LR,
                                              betas=adam_betas)

        # One independent optimizer per per-resolution discriminator head.
        self.discriminators = [self.disc.d64, self.disc.d128, self.disc.d256]
        self.disc_optimizers = [
            torch.optim.Adam(d.parameters(),
                             lr=DISCRIMINATOR_LR,
                             betas=adam_betas)
            for d in self.discriminators
        ]

    #@torch.no_grad()
    def train(self,
              dataset,
              epoch,
              batch_size=GAN_BATCH,
              test_sample_every=5,
              hist_avg=False,
              evaluator=None):
        """Adversarial training loop for the generator and discriminators.

        Trains over `dataset.train` for `epoch` epochs; every
        `test_sample_every` epochs it saves a grid of test-set samples and,
        when `evaluator` is supplied, records its scores (e.g. IS/FID).

        Args:
            dataset: object exposing `.train`, `.test`, `.vocab`,
                `.collate_fn` (batches provide 'img64'/'img128'/'img256',
                'caption', 'label').
            epoch: number of epochs to run.
            batch_size: training batch size (drop_last is used).
            test_sample_every: sampling/evaluation period in epochs.
            hist_avg: if True, keep a moving average of generator weights and
                copy it back into the generator every 1000 updates.
            evaluator: optional factory, called as
                evaluator(dataset, inception_model, batch_size, device).

        Returns:
            dict of per-epoch loss/accuracy lists (plus evaluator metrics).
        """

        start_time = time.strftime("%Y-%m-%d-%H-%M", time.gmtime())
        os.makedirs(f'{OUT_DIR}/{start_time}')

        # print('cun')
        # for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
        #     self.gen.eval()
        #     generated_samples = [resolution.unsqueeze(0) for resolution in self.sample_test_set(dataset)]
        #     self._save_generated(generated_samples, e, f'{OUT_DIR}/{start_time}')
        #
        #     return

        if hist_avg:
            # Snapshot of generator weights for the moving average.
            avg_g_params = deepcopy(list(p.data
                                         for p in self.gen.parameters()))

        loader_config = {
            'batch_size': batch_size,
            'shuffle': True,
            'drop_last': True,
            'collate_fn': dataset.collate_fn
        }

        train_loader = DataLoader(dataset.train, **loader_config)

        metrics = {
            'IS': [],
            'FID': [],
            'loss': {
                'g': [],
                'd': []
            },
            'accuracy': {
                'real': [],
                'fake': [],
                'mismatched': [],
                'unconditional_real': [],
                'unconditional_fake': []
            }
        }

        if evaluator is not None:
            evaluator = evaluator(dataset, self.damsm.img_enc.inception_model,
                                  batch_size, self.device)

        # Noise buffer reused (re-randomized in-place) for every batch.
        noise = torch.FloatTensor(batch_size, D_Z).to(self.device)
        gen_updates = 0

        self.disc.train()

        for e in tqdm(range(epoch), desc='Epochs', dynamic_ncols=True):
            self.gen.train(), self.disc.train()
            g_loss = 0
            w_loss = 0
            s_loss = 0
            kl_loss = 0
            # Per-resolution (64/128/256) running sums for this epoch.
            g_stage_loss = np.zeros(3, dtype=float)
            d_loss = np.zeros(3, dtype=float)
            real_acc = np.zeros(3, dtype=float)
            fake_acc = np.zeros(3, dtype=float)
            mismatched_acc = np.zeros(3, dtype=float)
            uncond_real_acc = np.zeros(3, dtype=float)
            uncond_fake_acc = np.zeros(3, dtype=float)
            disc_skips = np.zeros(3, dtype=int)

            train_pbar = tqdm(train_loader,
                              desc='Training',
                              leave=False,
                              dynamic_ncols=True)
            for batch in train_pbar:
                real_imgs = [batch['img64'], batch['img128'], batch['img256']]

                # Text encoder is frozen; no gradients through DAMSM.
                with torch.no_grad():
                    word_embs, sent_embs = self.damsm.txt_enc(batch['caption'])
                # Mask attention at positions holding the END token.
                attn_mask = torch.tensor(batch['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                # Generate images
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                # Discriminator loss (with label smoothing)
                batch_d_loss, batch_real_acc, batch_fake_acc, batch_mismatched_acc, batch_uncond_real_acc, batch_uncond_fake_acc, batch_disc_skips = self.discriminator_step(
                    real_imgs, generated, sent_embs, 0.1)

                d_grad_norm = [grad_norm(d) for d in self.discriminators]

                d_loss += batch_d_loss
                real_acc += batch_real_acc
                fake_acc += batch_fake_acc
                mismatched_acc += batch_mismatched_acc
                uncond_real_acc += batch_uncond_real_acc
                uncond_fake_acc += batch_uncond_fake_acc
                disc_skips += batch_disc_skips

                # Generator loss
                batch_g_losses = self.generator_step(generated, word_embs,
                                                     sent_embs, mu, logvar,
                                                     batch['label'])
                g_total, batch_g_stage_loss, batch_w_loss, batch_s_loss, batch_kl_loss = batch_g_losses
                g_stage_loss += batch_g_stage_loss
                w_loss += batch_w_loss
                s_loss += (batch_s_loss)
                kl_loss += (batch_kl_loss)
                gen_updates += 1

                avg_g_loss = g_total.item() / batch_size
                g_loss += float(avg_g_loss)

                if hist_avg:
                    # Exponential moving average of generator weights.
                    for p, avg_p in zip(self.gen.parameters(), avg_g_params):
                        avg_p.mul_(0.999).add_(0.001, p.data)

                    if gen_updates % 1000 == 0:
                        tqdm.write(
                            'Replacing generator weights with their moving average'
                        )
                        for p, avg_p in zip(self.gen.parameters(),
                                            avg_g_params):
                            p.data.copy_(avg_p)

                train_pbar.set_description(
                    f'Training (G: {grad_norm(self.gen):.2f}  '
                    f'D64: {d_grad_norm[0]:.2f}  '
                    f'D128: {d_grad_norm[1]:.2f}  '
                    f'D256: {d_grad_norm[2]:.2f})')

            batches = len(train_loader)

            # Convert running sums into per-batch epoch averages.
            g_loss /= batches
            g_stage_loss /= batches
            w_loss /= batches
            s_loss /= batches
            kl_loss /= batches
            d_loss /= batches
            real_acc /= batches
            fake_acc /= batches
            mismatched_acc /= batches
            uncond_real_acc /= batches
            uncond_fake_acc /= batches

            metrics['loss']['g'].append(g_loss)
            metrics['loss']['d'].append(d_loss)
            metrics['accuracy']['real'].append(real_acc)
            metrics['accuracy']['fake'].append(fake_acc)
            metrics['accuracy']['mismatched'].append(mismatched_acc)
            metrics['accuracy']['unconditional_real'].append(uncond_real_acc)
            metrics['accuracy']['unconditional_fake'].append(uncond_fake_acc)

            sep = '_' * 10
            tqdm.write(f'{sep}Epoch {e}{sep}')

            if e % test_sample_every == 0:
                # Save a grid of test-set samples; optionally score with
                # the evaluator (keys are expected to match `metrics`).
                self.gen.eval()
                generated_samples = [
                    resolution.unsqueeze(0)
                    for resolution in self.sample_test_set(dataset)
                ]
                self._save_generated(generated_samples, e,
                                     f'{OUT_DIR}/{start_time}')

                if evaluator is not None:
                    scores = evaluator.evaluate(self)
                    for k, v in scores.items():
                        metrics[k].append(v)
                        tqdm.write(f'{k}: {v:.2f}')

            tqdm.write(
                f'Generator avg loss: total({g_loss:.3f})  '
                f'stage0({g_stage_loss[0]:.3f})  stage1({g_stage_loss[1]:.3f})  stage2({g_stage_loss[2]:.3f})  '
                f'w({w_loss:.3f})  s({s_loss:.3f})  kl({kl_loss:.3f})')

            for i, _ in enumerate(self.discriminators):
                tqdm.write(f'Discriminator{i} avg: '
                           f'loss({d_loss[i]:.3f})  '
                           f'r-acc({real_acc[i]:.3f})  '
                           f'f-acc({fake_acc[i]:.3f})  '
                           f'm-acc({mismatched_acc[i]:.3f})  '
                           f'ur-acc({uncond_real_acc[i]:.3f})  '
                           f'uf-acc({uncond_fake_acc[i]:.3f})  '
                           f'skips({disc_skips[i]})')

        return metrics

    def sample_test_set(self,
                        dataset,
                        nb_samples=8,
                        nb_captions=2,
                        noise_variations=2):
        """Generate a tiled comparison grid from random test captions.

        Picks `nb_samples` random test entries and `nb_captions` random
        caption columns for each, generates all of them under
        `noise_variations` independent noise draws, and tiles the outputs so
        that rows are samples and columns are captions x noise variants.

        Returns:
            (img64, img128, img256) — one combined grid tensor per resolution.
        """
        subset = dataset.test
        sample_indices = np.random.choice(len(subset),
                                          nb_samples,
                                          replace=False)
        # Captions live in columns caption_0 .. caption_9 of the dataframe.
        cap_indices = np.random.choice(10, nb_captions, replace=False)
        texts = [
            subset.data[f'caption_{cap_idx}'].iloc[sample_idx]
            for sample_idx in sample_indices for cap_idx in cap_indices
        ]

        # One full generation pass per noise variant, same captions each time.
        generated_samples = [
            self.generate_from_text(texts, dataset)
            for _ in range(noise_variations)
        ]
        combined_img64 = torch.FloatTensor()
        combined_img128 = torch.FloatTensor()
        combined_img256 = torch.FloatTensor()

        for noise_variant in generated_samples:
            noise_var_img64 = torch.FloatTensor()
            noise_var_img128 = torch.FloatTensor()
            noise_var_img256 = torch.FloatTensor()
            for i in range(nb_samples):
                # rows: samples, columns: captions * noise variants
                row64 = torch.cat([
                    noise_variant[0][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                  dim=-1).cpu()
                row128 = torch.cat([
                    noise_variant[1][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                   dim=-1).cpu()
                row256 = torch.cat([
                    noise_variant[2][i * nb_captions + j]
                    for j in range(nb_captions)
                ],
                                   dim=-1).cpu()
                # Stack this sample's row below the previous ones (dim=-2).
                noise_var_img64 = torch.cat([noise_var_img64, row64], dim=-2)
                noise_var_img128 = torch.cat([noise_var_img128, row128],
                                             dim=-2)
                noise_var_img256 = torch.cat([noise_var_img256, row256],
                                             dim=-2)
            # Append this noise variant's whole grid to the right (dim=-1).
            combined_img64 = torch.cat([combined_img64, noise_var_img64],
                                       dim=-1)
            combined_img128 = torch.cat([combined_img128, noise_var_img128],
                                        dim=-1)
            combined_img256 = torch.cat([combined_img256, noise_var_img256],
                                        dim=-1)

        return combined_img64, combined_img128, combined_img256

    @staticmethod
    def KL_loss(mu, logvar):
        """Standard VAE KL divergence against N(0, I).

        Computes -0.5 * mean(1 + logvar - mu^2 - exp(logvar)) without
        mutating the inputs.
        """
        inner = (mu.pow(2) + logvar.exp()) * -1 + 1 + logvar
        return torch.mean(inner) * -0.5

    def generator_step(self, generated_imgs, word_embs, sent_embs, mu, logvar,
                       class_labels):
        """One optimizer step for the generator.

        Total loss = DAMSM word loss + DAMSM sentence loss + KL term, plus,
        for each resolution, conditional and unconditional BCE pushing the
        discriminators to call the fakes real.

        Returns:
            (g_total tensor, per-stage avg loss list, avg word loss,
            avg sentence loss, KL loss item) — averages are per sample.
        """
        self.gen.zero_grad()
        avg_stage_g_loss = [0, 0, 0]

        # DAMSM features are computed on the highest-resolution output only.
        local_features, global_features = self.damsm.img_enc(
            generated_imgs[-1])
        batch_size = sent_embs.size(0)
        # Each image should match the caption at the same batch index.
        match_labels = torch.LongTensor(range(batch_size)).to(self.device)

        w1_loss, w2_loss, _ = self.damsm.words_loss(local_features, word_embs,
                                                    class_labels, match_labels)
        w_loss = (w1_loss + w2_loss) * LAMBDA

        s1_loss, s2_loss = self.damsm.sentence_loss(global_features, sent_embs,
                                                    class_labels, match_labels)
        s_loss = (s1_loss + s2_loss) * LAMBDA

        kl_loss = self.KL_loss(mu, logvar)

        g_total = w_loss + s_loss + kl_loss

        for i, d in enumerate(self.discriminators):
            features = d(generated_imgs[i])
            fake_logits = d.logit(features, sent_embs)

            # The generator wants its fakes classified as real (label 1).
            real_labels = torch.ones_like(fake_logits).to(self.device)

            disc_error = F.binary_cross_entropy_with_logits(
                fake_logits, real_labels)

            # Unconditional (caption-free) head of the same discriminator.
            uncond_fake_logits = d.logit(features)
            uncond_disc_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, real_labels)

            stage_loss = disc_error + uncond_disc_error
            avg_stage_g_loss[i] = stage_loss.item() / batch_size
            g_total += stage_loss

        g_total.backward()
        self.gen_optimizer.step()

        return g_total, avg_stage_g_loss, w_loss.item(
        ) / batch_size, s_loss.item() / batch_size, kl_loss.item()

    def discriminator_step(self,
                           real_imgs,
                           generated_imgs,
                           sent_embs,
                           label_smoothing,
                           skip_acc_threshold=0.9,
                           p_flip=0.05,
                           halting=False):
        """One optimizer step for each of the three discriminators.

        Per resolution, the loss combines: real vs fake (conditional on the
        caption), mismatched caption as fake, and unconditional real/fake.
        With `halting=True`, a discriminator whose real+fake accuracy exceeds
        `skip_acc_threshold` skips its update to keep the game balanced.

        Note: `p_flip` is currently unused — the label-flipping code below is
        commented out.

        Returns:
            (avg loss, real/fake/mismatched/uncond-real/uncond-fake accuracy,
            skip flags), each a 3-element list indexed by resolution.
        """
        self.disc.zero_grad()
        batch_size = sent_embs.size(0)

        avg_d_loss = [0, 0, 0]
        real_accuracy = [0, 0, 0]
        fake_accuracy = [0, 0, 0]
        mismatched_accuracy = [0, 0, 0]
        uncond_real_accuracy = [0, 0, 0]
        uncond_fake_accuracy = [0, 0, 0]
        skipped = [0, 0, 0]

        for i, d in enumerate(self.discriminators):
            real_features = d(real_imgs[i].to(self.device))
            # detach(): no gradient flows back into the generator here.
            fake_features = d(generated_imgs[i].detach())

            real_logits = d.logit(real_features, sent_embs)

            # One-sided label smoothing: real target is 1 - label_smoothing.
            real_labels = torch.full_like(real_logits,
                                          1 - label_smoothing).to(self.device)
            fake_labels = torch.zeros_like(real_logits,
                                           dtype=torch.float).to(self.device)

            # flip_mask = torch.Tensor(real_labels.size()).bernoulli_(p_flip).type(torch.bool)
            # real_labels[flip_mask], fake_labels[flip_mask] = fake_labels[flip_mask], real_labels[flip_mask]

            real_error = F.binary_cross_entropy_with_logits(
                real_logits, real_labels)
            # Real images should be classified as real
            real_accuracy[i] = (real_logits >=
                                0).sum().item() / real_logits.numel()

            fake_logits = d.logit(fake_features, sent_embs)
            fake_error = F.binary_cross_entropy_with_logits(
                fake_logits, fake_labels)
            # Generated images should be classified as fake
            fake_accuracy[i] = (fake_logits <
                                0).sum().item() / fake_logits.numel()

            # Pair real images with the next batch entry's caption.
            mismatched_logits = d.logit(real_features,
                                        rotate_tensor(sent_embs, 1))
            mismatched_error = F.binary_cross_entropy_with_logits(
                mismatched_logits, fake_labels)
            # Images with mismatched descriptions should be classified as fake
            mismatched_accuracy[i] = (mismatched_logits < 0).sum().item(
            ) / mismatched_logits.numel()

            uncond_real_logits = d.logit(real_features)
            uncond_real_error = F.binary_cross_entropy_with_logits(
                uncond_real_logits, real_labels)
            uncond_real_accuracy[i] = (uncond_real_logits >= 0).sum().item(
            ) / uncond_real_logits.numel()

            uncond_fake_logits = d.logit(fake_features)
            uncond_fake_error = F.binary_cross_entropy_with_logits(
                uncond_fake_logits, fake_labels)
            uncond_fake_accuracy[i] = (uncond_fake_logits < 0).sum().item(
            ) / uncond_fake_logits.numel()

            # Weight the two real terms and three fake terms equally overall.
            error = (real_error + uncond_real_error) / 2 + (
                fake_error + uncond_fake_error + mismatched_error) / 3

            if not halting or fake_accuracy[i] + real_accuracy[
                    i] < skip_acc_threshold * 2:
                error.backward()
                self.disc_optimizers[i].step()
            else:
                # Discriminator is too strong: skip this update.
                skipped[i] = 1

            avg_d_loss[i] = error.item() / batch_size

        return avg_d_loss, real_accuracy, fake_accuracy, mismatched_accuracy, uncond_real_accuracy, uncond_fake_accuracy, skipped

    def generate_from_text(self, texts, dataset, noise=None):
        """Tokenize raw caption strings and generate images for them."""
        token_seqs = [dataset.train.encode_text(caption) for caption in texts]
        return self.generate_from_encoded_text(token_seqs, dataset, noise)

    def generate_from_encoded_text(self, encoded, dataset, noise=None):
        """Generate images from already-tokenized captions (no gradients).

        Draws fresh standard-normal noise when `noise` is not supplied.
        """
        with torch.no_grad():
            word_embs, sent_embs = self.damsm.txt_enc(encoded)
            # Mask attention at positions holding the END token.
            attn_mask = torch.tensor(encoded).to(
                self.device) == dataset.vocab[END_TOKEN]
            if noise is None:
                noise = torch.FloatTensor(len(encoded), D_Z).to(self.device)
                noise.data.normal_(0, 1)
            images, _att, _mu, _logvar = self.gen(noise, sent_embs,
                                                  word_embs, attn_mask)
        return images

    def _save_generated(self, generated, epoch, out_dir=OUT_DIR):
        """Save every generated sample at each resolution as a JPEG.

        Args:
            generated: 3-tuple/list of image batches (64, 128, 256 px), all
                sharing the same leading batch dimension.
            epoch: epoch number used in the output directory name.
            out_dir: parent directory for the per-epoch folder.

        Fixes: the three near-identical save_image calls are collapsed into a
        loop over resolutions, and makedirs uses exist_ok=True (consistent
        with `save()`) so re-sampling the same epoch no longer raises.
        """
        nb_samples = generated[0].size(0)
        save_dir = f'{out_dir}/epoch_{epoch:03}'
        os.makedirs(save_dir, exist_ok=True)

        for i in range(nb_samples):
            # One file per (sample, resolution); pixel values are in [-1, 1].
            for res, batch in zip((64, 128, 256), generated):
                save_image(batch[i],
                           f'{save_dir}/{i}_{res}.jpg',
                           normalize=True,
                           range=(-1, 1))

    def save(self, name, save_dir=GAN_MODEL_DIR, metrics=None):
        """Persist generator/discriminator weights, plus optional metrics JSON."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(self.gen.state_dict(), f'{save_dir}/{name}_generator.pt')
        torch.save(self.disc.state_dict(),
                   f'{save_dir}/{name}_discriminator.pt')
        if metrics is None:
            return
        with open(f'{save_dir}/{name}_metrics.json', 'w') as f:
            json.dump(pre_json_metrics(metrics), f)

    def load_(self, name, load_dir=GAN_MODEL_DIR):
        """Load saved weights into this instance and switch both nets to eval."""
        gen_path = f'{load_dir}/{name}_generator.pt'
        disc_path = f'{load_dir}/{name}_discriminator.pt'
        self.gen.load_state_dict(torch.load(gen_path))
        self.disc.load_state_dict(torch.load(disc_path))
        self.gen.eval()
        self.disc.eval()

    @staticmethod
    def load(name, damsm, load_dir=GAN_MODEL_DIR, device=DEVICE):
        """Alternate constructor: build an AttnGAN around `damsm`, then load weights."""
        gan = AttnGAN(damsm, device=device)
        gan.load_(name, load_dir)
        return gan

    def validate_test_set(self,
                          dataset,
                          batch_size=GAN_BATCH,
                          save_dir=f'{OUT_DIR}/test_samples'):
        """Generate one image per test caption and dump the 256px outputs to disk."""
        os.makedirs(save_dir, exist_ok=True)

        loader = DataLoader(dataset.test,
                            batch_size=batch_size,
                            shuffle=True,
                            drop_last=False,
                            collate_fn=dataset.collate_fn)
        loader = tqdm(loader,
                      dynamic_ncols=True,
                      leave=True,
                      desc='Generating samples for test set')

        self.gen.eval()
        file_idx = 0
        with torch.no_grad():
            for batch in loader:
                captions = batch['caption']
                word_embs, sent_embs = self.damsm.txt_enc(captions)
                # Mask attention at positions holding the END token.
                attn_mask = torch.tensor(captions).to(
                    self.device) == dataset.vocab[END_TOKEN]

                noise = torch.FloatTensor(len(captions), D_Z).to(self.device)
                noise.data.normal_(0, 1)
                generated, att, mu, logvar = self.gen(noise, sent_embs,
                                                      word_embs, attn_mask)

                # Only the highest-resolution stage is saved.
                for img in generated[-1]:
                    save_image(img,
                               f'{save_dir}/{file_idx}.jpg',
                               normalize=True,
                               range=(-1, 1))
                    file_idx += 1

    def get_d_score(self, imgs, sent_embs):
        """Conditional 256px-discriminator logits for `imgs` given sentence embeddings."""
        d256 = self.disc.d256
        feats = d256(imgs.to(self.device))
        return d256.logit(feats, sent_embs)

    def accept_prob(self, score1, score2):
        """Metropolis-Hastings acceptance probability from two calibrated scores.

        NOTE(review): divides by (1/score2 - 1), and 1/score1 — scores of
        exactly 1.0 or 0.0 would blow up; presumably calibrated probabilities
        stay strictly inside (0, 1). Confirm upstream.
        """
        inv_odds_current = 1 / score1 - 1
        inv_odds_proposal = 1 / score2 - 1
        return min(1, inv_odds_current / inv_odds_proposal)

    def d_scores_test(self, dataset):
        """Unconditional d256 sigmoid scores for every real test image, flattened to floats."""
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            d256 = self.disc.d256
            batched_scores = []
            for batch in loader:
                imgs = batch['img256'].to(self.device)
                logits = d256.logit(d256(imgs))
                batched_scores.append(torch.sigmoid(logits))
            flat_scores = [
                score.item() for group in batched_scores
                for score in group.reshape(-1)
            ]
        return flat_scores

    def z_test(self, scores, labels):
        """Z statistic of binary `labels` against predicted probabilities `scores`."""
        label_arr = np.array(labels)
        score_arr = np.array(scores)
        numerator = np.sum(label_arr - score_arr)
        denominator = np.sqrt(np.sum(score_arr * (1 - score_arr)))
        return numerator / denominator

    def d_scores_gen(self, dataset):
        """Unconditional d256 sigmoid scores for one generated image per test caption."""
        with torch.no_grad():
            loader = DataLoader(dataset.test,
                                batch_size=20,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            d256 = self.disc.d256
            batched_scores = []
            for batch in loader:
                captions = batch['caption']
                noise = torch.FloatTensor(len(captions), D_Z).to(self.device)
                noise.data.normal_(0, 1)
                word_embs, sent_embs = self.damsm.txt_enc(captions)
                # Mask attention at positions holding the END token.
                attn_mask = torch.tensor(captions).to(
                    self.device) == dataset.vocab[END_TOKEN]
                generated, _, _, _ = self.gen(noise, sent_embs, word_embs,
                                              attn_mask)

                # Score only the highest-resolution stage.
                logits = d256.logit(d256(generated[-1]))
                batched_scores.append(torch.sigmoid(logits))
            flat_scores = [
                score.item() for group in batched_scores
                for score in group.reshape(-1)
            ]
        return flat_scores

    def mh_sample(self, dataset, k, save_dir='test_samples', batch=GAN_BATCH):
        """Metropolis-Hastings selection over generator outputs, scored by FID.

        For each test caption, generates a chain of k candidate images, scores
        them with a calibrated (logistic-regression) 256px discriminator, and
        walks the chain with the MH acceptance rule, keeping the last accepted
        candidate. Prints calibration diagnostics and the FID of the selected
        images; returns the selected images.

        Bug fixed: the acceptance test (`if np.random.rand() < alpha`) was
        dedented outside the chain loop, so only the final candidate was ever
        considered, using a stale alpha. It now runs once per candidate.
        """
        evaluator = IS_FID_Evaluator(dataset,
                                     self.damsm.img_enc.inception_model, batch,
                                     self.device)
        # self.disc.d256.train()
        with torch.no_grad():
            l = len(dataset.test)
            # Raw (uncalibrated) discriminator probabilities.
            score_real = self.d_scores_test(dataset)
            score_gen = self.d_scores_gen(dataset)
            print(np.mean(score_real))
            print(np.mean(score_gen))
            # Hold out the last fifth of each score list for calibration.
            portion = -l // 5
            score_test = score_real[:portion] + score_gen[:portion]
            label_test = [1] * (len(score_test) //
                                2) + [0] * (len(score_test) // 2)

            print('Z test before calibration: ',
                  self.z_test(torch.tensor(score_test), label_test))

            score_real_calib = score_real[portion:]
            score_gen_calib = score_gen[portion:]
            # score_calib = score_real_calib + score_gen_calib
            score_calib = score_gen_calib + score_real_calib
            label_calib = len(score_gen_calib) * [0] + len(
                score_real_calib) * [1]

            # Platt-style calibration of the raw discriminator scores.
            cal_clf = LogisticRegression()
            cal_clf.fit(np.array(score_calib).reshape(-1, 1), label_calib)

            score_pred = cal_clf.predict_proba(
                np.array(score_test).reshape(-1, 1))[:, 1]
            print('Score pred avg: ', np.mean(score_pred))
            test_pred = cal_clf.predict(np.array(score_test).reshape(-1, 1))

            print('Z test after calibration: ',
                  self.z_test(score_pred, label_test))
            print('Accuracy: ',
                  sum((test_pred == label_test)) / len(test_pred))

            os.makedirs(save_dir, exist_ok=True)
            loader = DataLoader(dataset.test,
                                batch_size=1,
                                shuffle=False,
                                drop_last=False,
                                collate_fn=dataset.collate_fn)
            loader = tqdm(loader,
                          dynamic_ncols=True,
                          leave=True,
                          desc='Generating samples for test set')

            imgs = []
            true_probs = 0
            noaccept = 0
            for i, sample in enumerate(loader):
                # Skip the tail that was held out for calibration.
                if i > l - (l // 10):
                    continue
                word_embs, sent_embs = self.damsm.txt_enc(sample['caption'])
                attn_mask = torch.tensor(sample['caption']).to(
                    self.device) == dataset.vocab[END_TOKEN]

                # Build a chain of (at least) k candidate images, batch at a time.
                img_chain = []
                while len(img_chain) < k:
                    noise = torch.FloatTensor(batch, D_Z).to(self.device)
                    noise.data.normal_(0, 1)
                    generated, _, _, _ = self.gen(
                        noise, sent_embs.repeat(batch, 1),
                        word_embs.repeat(batch, 1, 1),
                        attn_mask.repeat(batch, 1))

                    for img in generated[-1]:
                        img_chain.append(img)

                img_chain = img_chain[:k]
                img_chain = torch.stack(img_chain).to(self.device)

                # Calibrated probability for every candidate in the chain.
                score_chain = []
                d_loader = DataLoader(img_chain,
                                      batch_size=batch,
                                      shuffle=False,
                                      drop_last=False)
                for d_batch in d_loader:
                    scores = self.get_d_score(d_batch,
                                              sent_embs.repeat(batch, 1))
                    scores = scores.reshape(-1, 1).cpu().numpy()
                    scores = cal_clf.predict_proba(scores)[:, 1]
                    for s in scores:
                        score_chain.append(s)

                # MH walk over the chain (fix: acceptance runs per candidate).
                chosen = 0
                for j, s in enumerate(score_chain[1:], 1):
                    alpha = self.accept_prob(score_chain[chosen], s)
                    if np.random.rand() < alpha:
                        chosen = j

                if chosen == 0:
                    # Nothing accepted: fall back to the best-scoring candidate.
                    # NOTE(review): argmax indexes score_chain[1:] but is applied
                    # to img_chain unshifted (off by one) — confirm intent.
                    imgs.append(img_chain[torch.tensor(
                        score_chain[1:]).argmax()].cpu())
                    noaccept += 1
                else:
                    imgs.append(img_chain[chosen].cpu())
                true_probs += score_chain[0]

            print(noaccept)
            print(true_probs / len(dataset.test))
            mu_real, sig_real = evaluator.mu_real, evaluator.sig_real
            mu_fake, sig_fake = activation_statistics(
                self.damsm.img_enc.inception_model, imgs)
            print('FID: ', frechet_dist(mu_real, sig_real, mu_fake, sig_fake))
            return imgs
# Exemple #20 (scraper artifact separator, commented out so the module parses)
# 0
def train_D_With_G():
    """Jointly train an ACGAN-style discriminator and generator (WGAN-AC hybrid).

    Relies on module-level globals: `trainloader`, `testloader`, `batch_size`,
    `gen_train`, `Discriminator`, `Generator`, `calc_gradient_penalty`,
    `avoidOverflow`, `plot` — all assumed to be defined elsewhere in this file.
    Requires CUDA. Saves sample grids to output/ and checkpoints each epoch.
    """
    aD = Discriminator()
    aD.cuda()

    aG = Generator()
    aG.cuda()

    optimizer_g = torch.optim.Adam(aG.parameters(), lr=0.0001, betas=(0, 0.9))
    optimizer_d = torch.optim.Adam(aD.parameters(), lr=0.0001, betas=(0, 0.9))

    criterion = nn.CrossEntropyLoss()

    n_z = 100
    n_classes = 10
    np.random.seed(352)
    # Fixed evaluation noise: the first n_z dims carry a one-hot class code,
    # 10 samples per class, so saved grids show all classes every epoch.
    label = np.asarray(list(range(10)) * 10)
    noise = np.random.normal(0, 1, (100, n_z))
    label_onehot = np.zeros((100, n_classes))
    label_onehot[np.arange(100), label] = 1
    noise[np.arange(100), :n_classes] = label_onehot[np.arange(100)]
    noise = noise.astype(np.float32)

    save_noise = torch.from_numpy(noise)
    save_noise = Variable(save_noise).cuda()
    start_time = time.time()

    # Train the model
    num_epochs = 500
    loss1 = []
    loss2 = []
    loss3 = []
    loss4 = []
    loss5 = []
    acc1 = []
    for epoch in range(0, num_epochs):

        aG.train()
        aD.train()
        avoidOverflow(optimizer_d)
        avoidOverflow(optimizer_g)
        for batch_idx, (X_train_batch,
                        Y_train_batch) in enumerate(trainloader):

            # Skip the (possibly smaller) final batch.
            if (Y_train_batch.shape[0] < batch_size):
                continue
            # train G
            # The generator is updated only every `gen_train` batches.
            if batch_idx % gen_train == 0:
                # Freeze D while updating G.
                for p in aD.parameters():
                    p.requires_grad_(False)

                aG.zero_grad()

                # Random class labels embedded as one-hot in the noise prefix.
                label = np.random.randint(0, n_classes, batch_size)
                noise = np.random.normal(0, 1, (batch_size, n_z))
                label_onehot = np.zeros((batch_size, n_classes))
                label_onehot[np.arange(batch_size), label] = 1
                noise[np.arange(batch_size), :n_classes] = label_onehot[
                    np.arange(batch_size)]
                noise = noise.astype(np.float32)
                noise = torch.from_numpy(noise)
                noise = Variable(noise).cuda()
                fake_label = Variable(torch.from_numpy(label)).cuda()

                fake_data = aG(noise)
                gen_source, gen_class = aD(fake_data)

                # Wasserstein source term + auxiliary classification term.
                gen_source = gen_source.mean()
                gen_class = criterion(gen_class, fake_label)

                gen_cost = -gen_source + gen_class
                gen_cost.backward()

                optimizer_g.step()

            # train D
            # Unfreeze D for its own update.
            for p in aD.parameters():
                p.requires_grad_(True)

            aD.zero_grad()

            # train discriminator with input from generator
            label = np.random.randint(0, n_classes, batch_size)
            noise = np.random.normal(0, 1, (batch_size, n_z))
            label_onehot = np.zeros((batch_size, n_classes))
            label_onehot[np.arange(batch_size), label] = 1
            noise[np.arange(batch_size), :n_classes] = label_onehot[np.arange(
                batch_size)]
            noise = noise.astype(np.float32)
            noise = torch.from_numpy(noise)
            noise = Variable(noise).cuda()
            fake_label = Variable(torch.from_numpy(label)).cuda()
            with torch.no_grad():
                fake_data = aG(noise)

            disc_fake_source, disc_fake_class = aD(fake_data)

            disc_fake_source = disc_fake_source.mean()
            disc_fake_class = criterion(disc_fake_class, fake_label)

            # train discriminator with input from the discriminator
            real_data = Variable(X_train_batch).cuda()
            real_label = Variable(Y_train_batch).cuda()

            disc_real_source, disc_real_class = aD(real_data)

            # Training accuracy of the auxiliary classifier on real data.
            prediction = disc_real_class.data.max(1)[1]
            accuracy = (float(prediction.eq(real_label.data).sum()) /
                        float(batch_size)) * 100.0

            disc_real_source = disc_real_source.mean()
            disc_real_class = criterion(disc_real_class, real_label)

            # WGAN-GP penalty on interpolates between real and fake.
            gradient_penalty = calc_gradient_penalty(aD, real_data, fake_data)

            disc_cost = disc_fake_source - disc_real_source + disc_real_class + disc_fake_class + gradient_penalty
            disc_cost.backward()

            optimizer_d.step()
            loss1.append(gradient_penalty.item())
            loss2.append(disc_fake_source.item())
            loss3.append(disc_real_source.item())
            loss4.append(disc_real_class.item())
            loss5.append(disc_fake_class.item())
            acc1.append(accuracy)
            if batch_idx % 50 == 0:
                print(epoch, batch_idx, "%.2f" % np.mean(loss1),
                      "%.2f" % np.mean(loss2), "%.2f" % np.mean(loss3),
                      "%.2f" % np.mean(loss4), "%.2f" % np.mean(loss5),
                      "%.2f" % np.mean(acc1))
        # Test the model
        aD.eval()
        with torch.no_grad():
            test_accu = []
            for batch_idx, (X_test_batch,
                            Y_test_batch) in enumerate(testloader):
                X_test_batch, Y_test_batch = Variable(
                    X_test_batch).cuda(), Variable(Y_test_batch).cuda()

                # NOTE(review): redundant — already inside the outer no_grad.
                with torch.no_grad():
                    _, output = aD(X_test_batch)

                prediction = output.data.max(1)[
                    1]  # first column has actual prob.
                accuracy = (float(prediction.eq(Y_test_batch.data).sum()) /
                            float(batch_size)) * 100.0
                test_accu.append(accuracy)
                # NOTE(review): accuracy_test would be unbound below if
                # testloader were empty — assumed non-empty.
                accuracy_test = np.mean(test_accu)
        print('Testing', accuracy_test, time.time() - start_time)

        # save output
        with torch.no_grad():
            aG.eval()
            samples = aG(save_noise)
            samples = samples.data.cpu().numpy()
            # Map from [-1, 1] to [0, 1] for plotting.
            samples += 1.0
            samples /= 2.0
            samples = samples.transpose(0, 2, 3, 1)
            aG.train()
        fig = plot(samples)
        plt.savefig('output/%s.png' % str(epoch).zfill(3), bbox_inches='tight')
        plt.close(fig)

        # Checkpoint every epoch (whole modules, not state_dicts).
        if (epoch + 1) % 1 == 0:
            torch.save(aG, 'tempG.model')
            torch.save(aD, 'tempD.model')

    torch.save(aG, 'generator.model')
    torch.save(aD, 'discriminator.model')
class Trainer:
    """Adversarial trainer that learns a linear mapping between two
    pretrained fastText embedding spaces.

    The top ``params.vocab_size`` word vectors of each model are rebuilt,
    normalized, SVD-decomposed and (optionally) given a shared singular
    spectrum, then frozen.  A ``Mapping`` over space 0 is trained against a
    ``Discriminator`` that tries to tell mapped space-0 vectors from space-1
    vectors (standard GAN loss or WGAN, with optional gradient penalty).
    """

    def __init__(self, params, *, n_samples=10000000):
        """Build embeddings, discriminator, mapping, optimizers and samplers.

        params: namespace of hyper-parameters (paths, dims, lr's, flags).
        n_samples: number of urns for each frequency-based WordSampler.
        """
        # Load both pretrained fastText models and their (word, freq) lists.
        self.model = [
            fastText.load_model(
                os.path.join(params.dataDir, params.model_path_0)),
            fastText.load_model(
                os.path.join(params.dataDir, params.model_path_1))
        ]
        self.dic = [
            list(zip(*self.model[id].get_words(include_freq=True)))
            for id in [0, 1]
        ]
        # Rebuild the top-vocab_size embedding matrices in float64 for SVD.
        x = [
            np.empty((params.vocab_size, params.emb_dim), dtype=np.float64)
            for _ in [0, 1]
        ]
        for id in [0, 1]:
            for i in range(params.vocab_size):
                x[id][i, :] = self.model[id].get_word_vector(
                    self.dic[id][i][0])
            x[id] = normalize_embeddings_np(x[id], params.normalize_pre)
        # SVD both spaces; optionally replace both spectra with their mean so
        # the two spaces start out spectrally aligned.
        u0, s0, _ = scipy.linalg.svd(x[0], full_matrices=False)
        u1, s1, _ = scipy.linalg.svd(x[1], full_matrices=False)
        if params.spectral_align_pre:
            s = (s0 + s1) * 0.5
            x[0] = u0 @ np.diag(s)
            x[1] = u1 @ np.diag(s)
        else:
            x[0] = u0 @ np.diag(s0)
            x[1] = u1 @ np.diag(s1)
        # Frozen lookup tables on the GPU; only the mapping/discriminator train.
        self.embedding = [
            nn.Embedding.from_pretrained(torch.from_numpy(x[id]).to(
                torch.float).to(GPU),
                                         freeze=True,
                                         sparse=True) for id in [0, 1]
        ]
        self.discriminator = Discriminator(
            params.emb_dim,
            n_layers=params.d_n_layers,
            n_units=params.d_n_units,
            drop_prob=params.d_drop_prob,
            drop_prob_input=params.d_drop_prob_input,
            leaky=params.d_leaky,
            batch_norm=params.d_bn).to(GPU)
        self.mapping = Mapping(params.emb_dim).to(GPU)
        # Discriminator optimizer (+ scheduler) — SGD or RMSProp only.
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(
                self.discriminator.parameters(),
                lr=params.d_lr,
                mode="max",
                wd=params.d_wd)

        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(
                self.discriminator.parameters(),
                params.n_steps,
                lr=params.d_lr,
                wd=params.d_wd)
        else:
            raise Exception(f"Optimizer {params.d_optimizer} not found.")
        # Mapping optimizer (+ scheduler) — SGD or RMSProp only.
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(
                self.mapping.parameters(),
                lr=params.m_lr,
                mode="max",
                wd=params.m_wd,
                factor=params.m_lr_decay,
                patience=params.m_lr_patience)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(
                self.mapping.parameters(),
                params.n_steps,
                lr=params.m_lr,
                wd=params.m_wd)
        else:
            raise Exception(f"Optimizer {params.m_optimizer} not found")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.wgan = params.wgan
        self.d_clip_mode = params.d_clip_mode
        if params.wgan:
            self.loss_fn = _wasserstein_distance
        else:
            # "elementwise_mean" was deprecated in PyTorch 1.0 and later
            # removed; "mean" is the equivalent supported spelling.
            self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        # Frequency-weighted word samplers, one per language.
        self.sampler = [
            WordSampler(self.dic[id],
                        n_urns=n_samples,
                        alpha=params.a_sample_factor,
                        top=params.a_sample_top) for id in [0, 1]
        ]
        self.d_bs = params.d_bs
        self.d_gp = params.d_gp

    def get_adv_batch(self, *, reverse, gp=False):
        """Sample one adversarial batch.

        reverse: if True, label mapped space-0 vectors as "real" (generator
            view); otherwise space-1 vectors are "real" (discriminator view).
        gp: also return an interpolation ``z`` for the gradient penalty.
        Returns (x, y) or (x, y, z): inputs of shape (2*d_bs, emb_dim),
        smoothed targets in [0, 1], and optionally the interpolated points.
        """
        batch = [
            torch.LongTensor(
                [self.sampler[id].sample()
                 for _ in range(self.d_bs)]).view(self.d_bs, 1).to(GPU)
            for id in [0, 1]
        ]
        # Embeddings are frozen, so look them up without tracking gradients.
        with torch.no_grad():
            x = [
                self.embedding[id](batch[id]).view(self.d_bs, -1)
                for id in [0, 1]
            ]
        # Smoothed labels: one half near 0, the other near 1.
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[:self.d_bs] = 1 - y[:self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        x[0] = self.mapping(x[0])
        if gp:
            # Random convex combination of mapped and target vectors for GP.
            t = torch.FloatTensor(self.d_bs,
                                  1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0])
            z = x[0] * t + x[1] * (1.0 - t)
            x = torch.cat(x, 0)
            return x, y, z
        else:
            x = torch.cat(x, 0)
            return x, y

    def adversarial_step(self):
        """One mapping (generator) update with reversed labels; returns loss."""
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        self.m_optimizer.step()
        self.mapping.clip_weights()
        return loss.item()

    def discriminator_step(self):
        """One discriminator update (with optional gradient penalty / WGAN
        weight clipping); returns the scalar loss."""
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        # Inputs are generated without a graph: the mapping is not updated here.
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            # Gradient penalty: push ||∇_z D(z)|| toward 1 on interpolates.
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out,
                              z,
                              grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True,
                              create_graph=True,
                              only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0)**2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        """Advance the mapping LR scheduler with the validation metric."""
        self.m_scheduler.step(metric)
# Exemple #22
# 0
class Trainer:
    """Joint trainer: fine-tunes two FastText models while adversarially
    learning an orthogonal linear mapping between their embedding spaces.

    Unlike the frozen-embedding variant, the FastText ``u``/``v`` matrices keep
    training (skip-gram loss) and can also receive adversarial gradients; the
    mapping is kept near-orthogonal via ``_orthogonalize``.
    """

    def __init__(self, corpus_data_0, corpus_data_1, *, params, n_samples=10000000):
        """Wire up models, optimizers, data queues and samplers.

        corpus_data_0 / corpus_data_1: corpus wrappers exposing ``.model`` and
            ``.dic`` for each language.
        params: hyper-parameter namespace; n_samples: WordSampler urn count.
        """
        self.fast_text = [FastText(corpus_data_0.model).to(GPU), FastText(corpus_data_1.model).to(GPU)]
        self.discriminator = Discriminator(params.emb_dim, n_layers=params.d_n_layers, n_units=params.d_n_units,
                                           drop_prob=params.d_drop_prob, drop_prob_input=params.d_drop_prob_input,
                                           leaky=params.d_leaky, batch_norm=params.d_bn).to(GPU)
        # Mapping starts as the identity matrix (no bias).
        self.mapping = nn.Linear(params.emb_dim, params.emb_dim, bias=False)
        self.mapping.weight.data.copy_(torch.diag(torch.ones(params.emb_dim)))
        self.mapping = self.mapping.to(GPU)
        # Per-language optimizers for the skip-gram (fastText) objective.
        self.ft_optimizer, self.ft_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(self.fast_text[id].parameters(),
                                                            lr=params.ft_lr, mode="max", factor=params.ft_lr_decay,
                                                            patience=params.ft_lr_patience)
            self.ft_optimizer.append(optimizer)
            self.ft_scheduler.append(scheduler)
        # Per-language optimizers for the adversarial updates of u and v.
        self.a_optimizer, self.a_scheduler = [], []
        for id in [0, 1]:
            optimizer, scheduler = optimizers.get_sgd_adapt(
                [{"params": self.fast_text[id].u.parameters()}, {"params": self.fast_text[id].v.parameters()}],
                lr=params.a_lr, mode="max", factor=params.a_lr_decay, patience=params.a_lr_patience)
            self.a_optimizer.append(optimizer)
            self.a_scheduler.append(scheduler)
        # Discriminator optimizer — SGD or RMSProp only.
        if params.d_optimizer == "SGD":
            self.d_optimizer, self.d_scheduler = optimizers.get_sgd_adapt(self.discriminator.parameters(),
                                                                          lr=params.d_lr, mode="max", wd=params.d_wd)

        elif params.d_optimizer == "RMSProp":
            self.d_optimizer, self.d_scheduler = optimizers.get_rmsprop_linear(self.discriminator.parameters(),
                                                                               params.n_steps,
                                                                               lr=params.d_lr, wd=params.d_wd)
        else:
            raise Exception(f"Optimizer {params.d_optimizer} not found.")
        # Mapping optimizer — SGD or RMSProp only.
        if params.m_optimizer == "SGD":
            self.m_optimizer, self.m_scheduler = optimizers.get_sgd_adapt(self.mapping.parameters(),
                                                                          lr=params.m_lr, mode="max", wd=params.m_wd,
                                                                          factor=params.m_lr_decay,
                                                                          patience=params.m_lr_patience)
        elif params.m_optimizer == "RMSProp":
            self.m_optimizer, self.m_scheduler = optimizers.get_rmsprop_linear(self.mapping.parameters(),
                                                                               params.n_steps,
                                                                               lr=params.m_lr, wd=params.m_wd)
        else:
            raise Exception(f"Optimizer {params.m_optimizer} not found")
        self.m_beta = params.m_beta
        self.smooth = params.smooth
        self.wgan = params.wgan
        self.d_clip_mode = params.d_clip_mode
        if params.wgan:
            self.loss_fn = _wasserstein_distance
        else:
            # "elementwise_mean" was deprecated in PyTorch 1.0 and later
            # removed; "mean" is the equivalent supported spelling.
            self.loss_fn = nn.BCEWithLogitsLoss(reduction="mean")
        # Background sentence queues feeding the skip-gram batches.
        self.corpus_data_queue = [
            _data_queue(corpus_data_0, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences,
                        batch_size=params.ft_bs),
            _data_queue(corpus_data_1, n_threads=(params.n_threads + 1) // 2, n_sentences=params.n_sentences,
                        batch_size=params.ft_bs)
        ]
        self.sampler = [
            WordSampler(corpus_data_0.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top),
            WordSampler(corpus_data_1.dic, n_urns=n_samples, alpha=params.a_sample_factor, top=params.a_sample_top)]
        self.d_bs = params.d_bs
        self.dic_0, self.dic_1 = corpus_data_0.dic, corpus_data_1.dic
        self.d_gp = params.d_gp

    def fast_text_step(self):
        """One skip-gram update per language; returns the two scalar losses."""
        losses = []
        for id in [0, 1]:
            self.ft_optimizer[id].zero_grad()
            u_b, v_b = self.corpus_data_queue[id].__next__()
            s = self.fast_text[id](u_b, v_b)
            loss = FastText.loss_fn(s)
            loss.backward()
            self.ft_optimizer[id].step()
            losses.append(loss.item())
        return losses[0], losses[1]

    def get_adv_batch(self, *, reverse, fix_embedding=False, gp=False):
        """Sample one adversarial batch from both languages.

        reverse: if True, mapped space-0 vectors get the "real" labels.
        fix_embedding: if True, no gradients flow back into the embeddings.
        gp: also return an interpolation ``z`` for the gradient penalty.
        Returns (x, y) or (x, y, z).
        """
        batch = [[self.sampler[id].sample() for _ in range(self.d_bs)]
                 for id in [0, 1]]
        batch = [self.fast_text[id].model.get_bag(batch[id], self.fast_text[id].u.weight.device)
                 for id in [0, 1]]
        # Same computation either way; the no_grad branch only blocks the
        # gradient from reaching the fastText embeddings.
        if fix_embedding:
            with torch.no_grad():
                x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]]
        else:
            x = [self.fast_text[id].u(batch[id][0], batch[id][1]).view(self.d_bs, -1) for id in [0, 1]]
        # Smoothed labels: one half near 0, the other near 1.
        y = torch.FloatTensor(self.d_bs * 2).to(GPU).uniform_(0.0, self.smooth)
        if reverse:
            y[: self.d_bs] = 1 - y[: self.d_bs]
        else:
            y[self.d_bs:] = 1 - y[self.d_bs:]
        x[0] = self.mapping(x[0])
        if gp:
            # Random convex combination of mapped and target vectors for GP.
            t = torch.FloatTensor(self.d_bs, 1).to(GPU).uniform_(0.0, 1.0).expand_as(x[0])
            z = x[0] * t + x[1] * (1.0 - t)
            x = torch.cat(x, 0)
            return x, y, z
        else:
            x = torch.cat(x, 0)
            return x, y

    def adversarial_step(self, fix_embedding=False):
        """One generator-side update of the embeddings (unless fixed) and the
        mapping, with reversed labels; re-orthogonalizes the mapping."""
        for id in [0, 1]:
            self.a_optimizer[id].zero_grad()
        self.m_optimizer.zero_grad()
        self.discriminator.eval()
        x, y = self.get_adv_batch(reverse=True, fix_embedding=fix_embedding)
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        loss.backward()
        for id in [0, 1]:
            self.a_optimizer[id].step()
        self.m_optimizer.step()
        _orthogonalize(self.mapping, self.m_beta)
        return loss.item()

    def discriminator_step(self):
        """One discriminator update (optional gradient penalty / WGAN weight
        clipping); returns the scalar loss."""
        self.d_optimizer.zero_grad()
        self.discriminator.train()
        # Inputs are generated without a graph: only the discriminator updates.
        with torch.no_grad():
            if self.d_gp > 0:
                x, y, z = self.get_adv_batch(reverse=False, gp=True)
            else:
                x, y = self.get_adv_batch(reverse=False)
                z = None
        y_hat = self.discriminator(x)
        loss = self.loss_fn(y_hat, y)
        if self.d_gp > 0:
            # Gradient penalty: push ||∇_z D(z)|| toward 1 on interpolates.
            z.requires_grad_()
            z_out = self.discriminator(z)
            g = autograd.grad(z_out, z, grad_outputs=torch.ones_like(z_out, device=GPU),
                              retain_graph=True, create_graph=True, only_inputs=True)[0]
            gp = torch.mean((g.norm(p=2, dim=1) - 1.0) ** 2)
            loss += self.d_gp * gp
        loss.backward()
        self.d_optimizer.step()
        if self.wgan:
            self.discriminator.clip_weights(self.d_clip_mode)
        return loss.item()

    def scheduler_step(self, metric):
        """Advance the fastText, adversarial and mapping schedulers."""
        for id in [0, 1]:
            self.ft_scheduler[id].step(metric)
            self.a_scheduler[id].step(metric)
        # self.d_scheduler.step(metric)
        self.m_scheduler.step(metric)
# Exemple #23
# 0
def main(args):
    """Train a 3D volume-interpolation generator, optionally with a GAN
    discriminator (LSGAN-style: MSE adversarial loss) and feature loss.

    args: parsed CLI namespace (paths, model/loss flags, schedules).
    Checkpoints and per-sub-epoch losses are written to ``args.save_dir``.
    """
    # log hyperparameter
    print(args)

    # select device.  NOTE: "cuda: 0" (with a space) is not a valid torch
    # device string and raises RuntimeError — the correct form is "cuda:0".
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    device = torch.device("cuda:0" if args.cuda else "cpu")

    # set random seed
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # data loader
    transform = transforms.Compose([
        utils.Normalize(),
        utils.ToTensor()
    ])
    train_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_train_list,
        max_k=args.training_step,
        train=True,
        transform=transform
    )
    test_dataset = TVDataset(
        root=args.root,
        sub_size=args.block_size,
        volume_list=args.volume_test_list,
        max_k=args.training_step,
        train=False,
        transform=transform
    )

    kwargs = {"num_workers": 4, "pin_memory": True} if args.cuda else {}
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                              shuffle=True, **kwargs)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size,
                             shuffle=False, **kwargs)

    # model weight initializers (Kaiming, matched to each net's nonlinearity)
    def generator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def discriminator_weights_init(m):
        if isinstance(m, nn.Conv3d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    g_model = Generator(args.upsample_mode, args.forward, args.backward, args.gen_sn, args.residual)
    g_model.apply(generator_weights_init)
    if args.data_parallel and torch.cuda.device_count() > 1:
        g_model = nn.DataParallel(g_model)
    g_model.to(device)

    if args.gan_loss != "none":
        d_model = Discriminator(args.dis_sn)
        d_model.apply(discriminator_weights_init)
        # if args.dis_sn:
        #     d_model = add_sn(d_model)
        if args.data_parallel and torch.cuda.device_count() > 1:
            d_model = nn.DataParallel(d_model)
        d_model.to(device)

    mse_loss = nn.MSELoss()
    adversarial_loss = nn.MSELoss()  # LSGAN: least-squares adversarial loss
    train_losses, test_losses = [], []
    d_losses, g_losses = [], []

    # optimizer
    g_optimizer = optim.Adam(g_model.parameters(), lr=args.lr,
                             betas=(args.beta1, args.beta2))
    if args.gan_loss != "none":
        d_optimizer = optim.Adam(d_model.parameters(), lr=args.d_lr,
                                 betas=(args.beta1, args.beta2))

    Tensor = torch.cuda.FloatTensor if args.cuda else torch.FloatTensor

    # load checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint {}".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint["epoch"]
            g_model.load_state_dict(checkpoint["g_model_state_dict"])
            # g_optimizer.load_state_dict(checkpoint["g_optimizer_state_dict"])
            if args.gan_loss != "none":
                d_model.load_state_dict(checkpoint["d_model_state_dict"])
                # d_optimizer.load_state_dict(checkpoint["d_optimizer_state_dict"])
                d_losses = checkpoint["d_losses"]
                g_losses = checkpoint["g_losses"]
            train_losses = checkpoint["train_losses"]
            test_losses = checkpoint["test_losses"]
            print("=> load chekcpoint {} (epoch {})"
                  .format(args.resume, checkpoint["epoch"]))

    # main loop
    for epoch in tqdm(range(args.start_epoch, args.epochs)):
        # training..
        g_model.train()
        if args.gan_loss != "none":
            d_model.train()
        train_loss = 0.
        volume_loss_part = np.zeros(args.training_step)
        for i, sample in enumerate(train_loader):
            params = list(g_model.named_parameters())
            # pdb.set_trace()
            # params[0][1].register_hook(lambda g: print("{}.grad: {}".format(params[0][0], g)))
            # adversarial ground truths
            real_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(1.0), requires_grad=False)
            fake_label = Variable(Tensor(sample["v_i"].shape[0], sample["v_i"].shape[1], 1, 1, 1, 1).fill_(0.0), requires_grad=False)

            v_f = sample["v_f"].to(device)
            v_b = sample["v_b"].to(device)
            v_i = sample["v_i"].to(device)
            g_optimizer.zero_grad()
            fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)

            # adversarial loss
            # update discriminator
            if args.gan_loss != "none":
                avg_d_loss = 0.
                avg_d_loss_real = 0.
                avg_d_loss_fake = 0.
                for k in range(args.n_d):
                    d_optimizer.zero_grad()
                    decisions = d_model(v_i)
                    d_loss_real = adversarial_loss(decisions, real_label)
                    fake_decisions = d_model(fake_volumes.detach())

                    d_loss_fake = adversarial_loss(fake_decisions, fake_label)
                    d_loss = d_loss_real + d_loss_fake
                    d_loss.backward()
                    avg_d_loss += d_loss.item() / args.n_d
                    # .item() so only scalars are accumulated (no live graphs)
                    avg_d_loss_real += d_loss_real.item() / args.n_d
                    avg_d_loss_fake += d_loss_fake.item() / args.n_d

                    d_optimizer.step()

            # update generator
            if args.gan_loss != "none":
                avg_g_loss = 0.
            avg_loss = 0.
            for k in range(args.n_g):
                loss = 0.
                g_optimizer.zero_grad()

                # adversarial loss
                if args.gan_loss != "none":
                    fake_decisions = d_model(fake_volumes)
                    g_loss = args.gan_loss_weight * adversarial_loss(fake_decisions, real_label)
                    loss += g_loss
                    avg_g_loss += g_loss.item() / args.n_g

                # volume loss
                if args.volume_loss:
                    volume_loss = args.volume_loss_weight * mse_loss(v_i, fake_volumes)
                    for j in range(v_i.shape[1]):
                        # scalar accumulation only; keeping the tensors here
                        # would retain the autograd graph across the loop
                        volume_loss_part[j] += mse_loss(v_i[:, j, :], fake_volumes[:, j, :]).item() / args.n_g / args.log_every
                    loss += volume_loss

                # feature loss
                if args.feature_loss:
                    feat_real = d_model.extract_features(v_i)
                    feat_fake = d_model.extract_features(fake_volumes)
                    for m in range(len(feat_real)):
                        loss += args.feature_loss_weight / len(feat_real) * mse_loss(feat_real[m], feat_fake[m])

                avg_loss += loss.item() / args.n_g
                loss.backward()
                g_optimizer.step()

            train_loss += avg_loss

            # log training status
            subEpoch = (i + 1) // args.log_every
            if (i+1) % args.log_every == 0:
                print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
                    epoch, (i+1) * args.batch_size, len(train_loader.dataset), 100. * (i+1) / len(train_loader),
                    avg_loss
                ))
                print("Volume Loss: ")
                for j in range(volume_loss_part.shape[0]):
                    print("\tintermediate {}: {:.6f}".format(
                        j+1, volume_loss_part[j]
                    ))

                if args.gan_loss != "none":
                    print("DLossReal: {:.6f} DLossFake: {:.6f} DLoss: {:.6f}, GLoss: {:.6f}".format(
                        avg_d_loss_real, avg_d_loss_fake, avg_d_loss, avg_g_loss
                    ))
                    d_losses.append(avg_d_loss)
                    g_losses.append(avg_g_loss)
                # train_losses.append(avg_loss)
                train_losses.append(train_loss / args.log_every)
                print("====> SubEpoch: {} Average loss: {:.6f} Time {}".format(
                    subEpoch, train_loss / args.log_every, time.asctime(time.localtime(time.time()))
                ))
                train_loss = 0.
                volume_loss_part = np.zeros(args.training_step)

            # testing...
            if (i + 1) % args.test_every == 0:
                g_model.eval()
                if args.gan_loss != "none":
                    d_model.eval()
                test_loss = 0.
                with torch.no_grad():
                    # use distinct names: reusing ``i``/``sample`` here would
                    # clobber the outer loop's batch index and corrupt the
                    # log/checkpoint schedule below
                    for t_idx, test_sample in enumerate(test_loader):
                        v_f = test_sample["v_f"].to(device)
                        v_b = test_sample["v_b"].to(device)
                        v_i = test_sample["v_i"].to(device)
                        fake_volumes = g_model(v_f, v_b, args.training_step, args.wo_ori_volume, args.norm)
                        test_loss += args.volume_loss_weight * mse_loss(v_i, fake_volumes).item()

                test_losses.append(test_loss * args.batch_size / len(test_loader.dataset))
                print("====> SubEpoch: {} Test set loss {:4f} Time {}".format(
                    subEpoch, test_losses[-1], time.asctime(time.localtime(time.time()))
                ))

            # saving...
            if (i+1) % args.check_every == 0:
                print("=> saving checkpoint at epoch {}".format(epoch))
                # NOTE(review): the "_pth.tar" suffix looks like a typo for
                # ".pth.tar", but it is kept as-is so resume paths still match.
                if args.gan_loss != "none":
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict":  g_optimizer.state_dict(),
                                "d_model_state_dict": d_model.state_dict(),
                                "d_optimizer_state_dict": d_optimizer.state_dict(),
                                "d_losses": d_losses,
                                "g_losses": g_losses,
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar")
                               )
                else:
                    torch.save({"epoch": epoch + 1,
                                "g_model_state_dict": g_model.state_dict(),
                                "g_optimizer_state_dict": g_optimizer.state_dict(),
                                "train_losses": train_losses,
                                "test_losses": test_losses},
                               os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + "_" + "pth.tar")
                               )
                torch.save(g_model.state_dict(),
                           os.path.join(args.save_dir, "model_" + str(epoch) + "_" + str(subEpoch) + ".pth"))

        num_subEpoch = len(train_loader) // args.log_every
        print("====> Epoch: {} Average loss: {:.6f} Time {}".format(
            epoch, np.array(train_losses[-num_subEpoch:]).mean(), time.asctime(time.localtime(time.time()))
        ))
# Exemple #24
# 0
def _train_discriminator_epochs(discriminator, dis_data_loader, dis_criterion,
                                dis_opt, n_epochs=3):
    """Run ``n_epochs`` full passes of discriminator training over the data
    currently loaded in ``dis_data_loader``.

    Returns a list with one mean NLL loss per pass (same values the inlined
    loops produced before this was factored out).
    """
    d_total_loss = []
    for _ in range(n_epochs):
        dis_data_loader.reset_pointer()
        total_loss = []
        for it in range(dis_data_loader.num_batch):
            x_batch, y_batch = dis_data_loader.next_batch()
            x_batch = x_batch.to(device)
            y_batch = y_batch.to(device)
            dis_output = discriminator(x_batch.detach())
            d_loss = dis_criterion(dis_output, y_batch.detach())
            dis_opt.zero_grad()
            d_loss.backward()
            dis_opt.step()
            total_loss.append(d_loss.data.cpu().numpy())
        d_total_loss.append(np.mean(total_loss))
    return d_total_loss


def main():
    """SeqGAN training: MLE pre-train the generator, pre-train the
    discriminator, then alternate policy-gradient generator updates (REINFORCE
    with roll-out rewards) and discriminator updates.

    Progress (NLL vs. an oracle LSTM, Self-BLEU, discriminator loss) is
    printed and appended to ``save/experiment-log.txt``.
    """
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)
    gen_data_loader = Gen_Data_loader(BATCH_SIZE)
    likelihood_data_loader = Gen_Data_loader(BATCH_SIZE)  # For testing
    vocab_size = 2000
    dis_data_loader = Dis_dataloader(BATCH_SIZE)

    generator = Generator(vocab_size, EMB_DIM, HIDDEN_DIM, 1, START_TOKEN,
                          SEQ_LENGTH).to(device)
    # Fixed random "oracle" LSTM used only to score generated samples.
    target_lstm = Generator(vocab_size,
                            EMB_DIM,
                            HIDDEN_DIM,
                            1,
                            START_TOKEN,
                            SEQ_LENGTH,
                            oracle=True).to(device)
    discriminator = Discriminator(vocab_size, dis_embedding_dim,
                                  dis_filter_sizes, dis_num_filters,
                                  dis_dropout).to(device)

    generate_samples(target_lstm, BATCH_SIZE, generated_num, positive_file)
    gen_data_loader.create_batches(positive_file)

    pre_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2)
    adv_gen_opt = torch.optim.Adam(generator.parameters(), 1e-2)
    dis_opt = torch.optim.Adam(discriminator.parameters(), 1e-4)
    dis_criterion = nn.NLLLoss()

    log = open('save/experiment-log.txt', 'w')
    print('Start pre-training...')
    log.write('pre-training...\n')
    # Stage 1: maximum-likelihood pre-training of the generator.
    for epoch in range(PRE_EPOCH_NUM):
        loss = pre_train_epoch(generator, pre_gen_opt, gen_data_loader)
        if (epoch + 1) % 5 == 0:
            generate_samples(generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(target_lstm, likelihood_data_loader)
            print('pre-train epoch ', epoch + 1, '\tnll:\t', test_loss)
            buffer = 'epoch:\t' + str(epoch +
                                      1) + '\tnll:\t' + str(test_loss) + '\n'
            log.write(buffer)
    print('Start pre-training discriminator...')
    # Train 3 epoch on the generated data and do this for 50 times
    for e in range(50):
        generate_samples(generator, BATCH_SIZE, generated_num, negative_file)
        dis_data_loader.load_train_data(positive_file, negative_file)
        d_total_loss = _train_discriminator_epochs(discriminator,
                                                   dis_data_loader,
                                                   dis_criterion, dis_opt)
        if (e + 1) % 5 == 0:
            buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format(
                e + 1, np.mean(d_total_loss))
            print(buffer)
            log.write(buffer)

    rollout = Rollout(generator, 0.8)
    print(
        '#########################################################################'
    )
    print('Start Adversarial Training...')
    log.write('adversarial training...\n')
    gan_loss = GANLoss()
    # Stage 2: adversarial training.
    for total_batch in range(TOTAL_BATCH):
        # Train the generator for one step
        discriminator.eval()
        for it in range(1):
            samples, _ = generator.sample(num_samples=BATCH_SIZE)
            # Monte-Carlo roll-out rewards from the (frozen) discriminator.
            rewards = rollout.get_reward(samples, 16, discriminator)
            prob = generator(samples.detach())
            adv_loss = gan_loss(prob, samples.detach(), rewards.detach())
            adv_gen_opt.zero_grad()
            adv_loss.backward()
            nn.utils.clip_grad_norm_(generator.parameters(), 5.0)
            adv_gen_opt.step()

        # Test
        if (total_batch + 1) % 5 == 0:
            generate_samples(generator, BATCH_SIZE, generated_num, eval_file)
            likelihood_data_loader.create_batches(eval_file)
            test_loss = target_loss(target_lstm, likelihood_data_loader)
            self_bleu_score = self_bleu(generator)
            buffer = 'epoch:\t' + str(total_batch + 1) + '\tnll:\t' + str(
                test_loss) + '\tSelf Bleu:\t' + str(self_bleu_score) + '\n'
            print(buffer)
            log.write(buffer)

        # Update roll-out parameters
        rollout.update_params()

        # Train the discriminator
        discriminator.train()
        for _ in range(5):
            generate_samples(generator, BATCH_SIZE, generated_num,
                             negative_file)
            dis_data_loader.load_train_data(positive_file, negative_file)
            d_total_loss = _train_discriminator_epochs(discriminator,
                                                       dis_data_loader,
                                                       dis_criterion, dis_opt)
            if (total_batch + 1) % 5 == 0:
                buffer = 'Epoch [{}], discriminator loss [{:.4f}]\n'.format(
                    total_batch + 1, np.mean(d_total_loss))
                print(buffer)
                log.write(buffer)
    log.close()
# Exemple #25
# 0
class Trainer:
    def __init__(self):
        """Initialize all slots to placeholder values; real objects are
        created later in train()."""
        # Bookkeeping / helpers
        for attr in ("logger", "tester", "latent", "save_path",
                     "loss_computer"):
            setattr(self, attr, None)
        self.epoch = 0
        self.step = 0
        self.tf_prob = 0  # scheduled-sampling teacher-forcing probability
        # Models (instantiated in train())
        for attr in ("encoder", "latent_compressor", "latent_decompressor",
                     "decoder", "generator"):
            setattr(self, attr, None)
        aae_enabled = config["train"]["aae"]
        if aae_enabled:
            self.discriminator = None
        # Optimizers
        self.encoder_optimizer = None
        self.decoder_optimizer = None
        self.criterion = None
        if aae_enabled:
            self.disc_optimizer = None
            self.gen_optimizer = None
            self.train_discriminator_not_generator = True
            self.disc_losses = []
            self.gen_losses = []
            self.disc_loss_init = None
            self.gen_loss_init = None
            # Starts negative so the first 0.1 increment lands exactly on 0.
            self.beta = -0.1
            self.reg_optimizer = None

    def test_losses(self, loss):
        losses = [loss]
        names = ["loss"]
        for ls, name in zip(losses, names):
            print("********************** Optimized by " + name)
            self.encoder_optimizer.zero_grad(set_to_none=True)
            self.decoder_optimizer.zero_grad(set_to_none=True)
            ls.backward(retain_graph=True)
            for model in [
                    self.encoder, self.latent_compressor,
                    self.latent_decompressor, self.decoder, self.generator
            ]:  # removed latent compressor
                for module_name, parameter in model.named_parameters():
                    if parameter.grad is not None:
                        print(module_name)
        self.encoder_optimizer.zero_grad(set_to_none=True)
        self.decoder_optimizer.zero_grad(set_to_none=True)
        (losses[0]).backward(retain_graph=True)
        print("********************** NOT OPTIMIZED BY NOTHING")
        for model in [
                self.encoder, self.latent_compressor, self.latent_decompressor,
                self.decoder, self.generator
        ]:  # removed latent compressor
            for module_name, parameter in model.named_parameters():
                if parameter.grad is None:
                    print(module_name)

    def run_mb(self, batch):
        """Run one mini-batch through the full model.

        `batch` is a (srcs, trgs) pair of token tensors.  The method encodes
        every bar, optionally compresses/decompresses the latents, decodes
        (with optional scheduled sampling), computes loss and accuracy
        (backpropagating when in training mode), optionally logs attention
        images, and optionally runs the adversarial (WGAN-GP style) updates.

        Returns a tuple of scalar values: (loss, accuracy, 0, 0, 0, 0),
        extended with six adversarial metrics when the AAE branch executes.
        """
        # SETUP VARIABLES
        srcs, trgs = batch
        srcs = torch.LongTensor(srcs.long()).to(
            config["train"]["device"]).transpose(0, 2)
        trgs = torch.LongTensor(trgs.long()).to(
            config["train"]["device"]).transpose(0, 2)  # invert batch and bars

        latent = None
        # One Batch (src/trg + masks) per bar.  NOTE: the loop variables below
        # deliberately shadow the `batch` parameter; after the decoding loop
        # `batch` refers to the last bar's Batch object.
        batches = [
            Batch(srcs[i], trgs[i], config["tokens"]["pad"])
            for i in range(n_bars)
        ]
        ############
        # ENCODING #
        ############
        latents = []
        for batch in batches:
            latent = self.encoder(batch.src, batch.src_mask)
            latents.append(latent)

        ############
        # COMPRESS #
        ############
        # Keep an untouched copy for the AAE pass, which re-encodes from it.
        old_batches = copy.deepcopy(batches)
        if config["train"]["compress_latents"]:
            latent = self.latent_compressor(
                latents)  # in: 3, 4, 200, 256, out: 3, 256

        self.latent = latent.detach().cpu().numpy()

        if config["train"]["compress_latents"]:
            latents = self.latent_decompressor(
                latent)  # in 3, 256, out: 3, 4, 200, 256
            for i in range(n_bars):
                # Decompressed latents are no longer token-aligned: unmask
                # everything and truncate the mask to the decompressed length.
                batches[i].src_mask = batches[i].src_mask.fill_(
                    True)[:, :, :, :20]

        ############
        # DECODING #
        ############
        # Scheduled sampling for transformer
        if config["train"]["scheduled_sampling"] and self.step > config[
                "train"]["after_steps_mix_sequences"]:
            for _ in range(1):  # K: single sampling round (placeholder loop)
                self.tf_prob = 0.5

                # First pass: greedy predictions under teacher forcing.
                predicted = []
                for batch, latent in zip(batches, latents):
                    out = self.decoder(batch.trg, latent, batch.src_mask,
                                       batch.trg_mask)
                    prob = self.generator(out)
                    prob = torch.max(prob, dim=-1).indices
                    predicted.append(prob)

                # add sos at beginning and cut last token
                for i in range(n_bars):
                    sos = torch.full_like(predicted[i],
                                          config["tokens"]["sos"])[..., :1].to(
                                              predicted[i].device)
                    pred = torch.cat((sos, predicted[i]), dim=-1)[..., :-1]
                    # create mixed trg: keep ground truth with probability
                    # tf_prob, otherwise substitute the model's own prediction
                    mixed_prob = torch.rand(batches[i].trg.shape,
                                            dtype=torch.float32).to(
                                                trgs.device)
                    mixed_prob = mixed_prob < self.tf_prob
                    batches[i].trg = batches[i].trg.where(mixed_prob, pred)

        outs = []
        for batch, latent in zip(batches, latents):
            out = self.decoder(batch.trg, latent, batch.src_mask,
                               batch.trg_mask)
            outs.append(out)

        # Format results
        outs = torch.stack(outs, dim=0)

        #####################
        # LOSS AND ACCURACY #
        #####################
        trg_ys = torch.stack([batch.trg_y for batch in batches])
        bars, n_track, n_batch, seq_len, d_model = outs.shape
        outs = outs.permute(1, 2, 0, 3,
                            4).reshape(n_track, n_batch, bars * seq_len,
                                       d_model)  # join bars
        trg_ys = trg_ys.permute(1, 2, 0, 3).reshape(n_track, n_batch,
                                                    bars * seq_len)

        # NOTE(review): `batch.ntokens` is the last bar's token count (the
        # shadowed loop variable) — confirm this is the intended normalizer.
        loss, accuracy = SimpleLossCompute(self.generator, self.criterion)(
            outs, trg_ys, batch.ntokens)  # join instr

        if self.generator.training:
            self.encoder_optimizer.zero_grad()
            self.decoder_optimizer.zero_grad()

            loss.backward()
            self.encoder_optimizer.step()
            self.decoder_optimizer.step()

        # Four zero placeholders keep the tuple shape expected by the logger.
        losses = (loss.item(), accuracy, 0, 0, 0, 0)

        # LOG IMAGES
        if self.encoder.training and config["train"]["log_images"] and \
                self.step % config["train"]["after_steps_log_images"] == 0 and self.step > 0:

            # Attention maps of sample 0, ordered [drums, guitar, bass,
            # strings] for each of: encoder self-attention, decoder
            # self-attention, and decoder source (cross) attention.
            enc = self.encoder
            dec = self.decoder
            enc_attention = [
                self._gather_attn(enc.drums_encoder.layers, "self_attn"),
                self._gather_attn(enc.guitar_encoder.layers, "self_attn"),
                self._gather_attn(enc.bass_encoder.layers, "self_attn"),
                self._gather_attn(enc.strings_encoder.layers, "self_attn"),
            ]
            dec_attention = [
                self._gather_attn(dec.drums_decoder.layers, "self_attn"),
                self._gather_attn(dec.guitar_decoder.layers, "self_attn"),
                self._gather_attn(dec.bass_decoder.layers, "self_attn"),
                self._gather_attn(dec.strings_decoder.layers, "self_attn"),
            ]
            src_attention = [
                self._gather_attn(dec.drums_decoder.layers, "src_attn"),
                self._gather_attn(dec.guitar_decoder.layers, "src_attn"),
                self._gather_attn(dec.bass_decoder.layers, "src_attn"),
                self._gather_attn(dec.strings_decoder.layers, "src_attn"),
            ]
            print("Logging images...")
            if config["train"]["compress_latents"]:
                self.logger.log_latent(self.latent)
            self.logger.log_attn_heatmap(enc_attention, dec_attention,
                                         src_attention)
            self.logger.log_examples(srcs, trgs)

        ####################
        # UPDATE GENERATOR #
        ####################
        # Adversarial (AAE) branch: push the compressed latent distribution
        # towards the prior via a critic.
        if config["train"][
                "aae"] and self.encoder.training and self.step > config[
                    "train"]["after_steps_train_aae"]:

            # Anneal beta upwards in 0.1 steps until max_beta.
            if self.step % config["train"][
                    "increase_beta_every"] == 0 and self.beta < config[
                        "train"]["max_beta"]:
                self.beta += 0.1

            if self.beta > 0:
                # To suppress warnings
                D_real = 0
                D_fake = 0
                loss_critic = 0

                ########################
                # UPDATE DISCRIMINATOR #
                ########################
                # Freeze the encoder side; train only the critic.
                for p in self.encoder.parameters():
                    p.requires_grad = False
                for p in self.latent_compressor.parameters():
                    p.requires_grad = False
                for p in self.discriminator.parameters():
                    p.requires_grad = True

                latents = []
                for batch in old_batches:
                    latent = self.encoder(batch.src, batch.src_mask)
                    latents.append(latent)
                latent = self.latent_compressor(latents)

                for _ in range(config["train"]["critic_iterations"]):
                    prior = get_prior(
                        (config["train"]["batch_size"],
                         config["model"]["d_model"]))  # autograd is intern
                    D_real = self.discriminator(prior).reshape(-1)

                    D_fake = self.discriminator(latent).reshape(-1)

                    gradient_penalty = calc_gradient_penalty(
                        self.discriminator, prior.data, latent.data)

                    # Critic loss (WGAN-GP form), scaled by the annealed beta.
                    loss_critic = (
                        torch.mean(D_fake) - torch.mean(D_real) +
                        config["train"]["lambda"] * gradient_penalty)
                    loss_critic = loss_critic * self.beta

                    self.discriminator.zero_grad()
                    loss_critic.backward(retain_graph=True)

                    self.disc_optimizer.step(lr=self.encoder_optimizer.lr)

                ####################
                # UPDATE GENERATOR #
                ####################
                # Unfreeze the encoder side; freeze the critic.
                for p in self.encoder.parameters():
                    p.requires_grad = True
                for p in self.latent_compressor.parameters():
                    p.requires_grad = True
                for p in self.discriminator.parameters():
                    p.requires_grad = False  # to avoid computation

                latents = []
                for batch in old_batches:
                    latent = self.encoder(batch.src, batch.src_mask)
                    latents.append(latent)
                latent = self.latent_compressor(latents)

                G = self.discriminator(latent).reshape(-1)

                loss_gen = -torch.mean(G)
                loss_gen = loss_gen * self.beta

                self.gen_optimizer.zero_grad()

                loss_gen.backward()

                self.gen_optimizer.step(lr=self.encoder_optimizer.lr)

                # Append the six adversarial metrics expected by the logger.
                losses += (D_real.mean().cpu().data.numpy(),
                           D_fake.mean().cpu().data.numpy(),
                           G.mean().cpu().data.numpy(),
                           loss_critic.cpu().data.numpy(),
                           loss_gen.cpu().data.numpy(),
                           D_real.mean().cpu().data.numpy() -
                           D_fake.mean().cpu().data.numpy())

        return losses

    def _gather_attn(self, layers, attn_name):
        """Return the attention maps of sample 0 as a [layer][head] nested list.

        `attn_name` selects which attention module to read from each layer:
        "self_attn" (self-attention) or "src_attn" (encoder-decoder attention).
        """
        return [list(getattr(layer, attn_name).attn[0]) for layer in layers]

    def train(self):
        # Create checkpoint folder
        if not os.path.exists(config["paths"]["checkpoints"]):
            os.makedirs(config["paths"]["checkpoints"])
        timestamp = str(datetime.now())
        timestamp = timestamp[:timestamp.index('.')]
        timestamp = timestamp.replace(' ', '_').replace(':', '-')
        self.save_path = config["paths"]["checkpoints"] + os.sep + timestamp
        os.mkdir(self.save_path)

        # Create models
        self.latent_compressor = LatentCompressor(
            config["model"]["d_model"]).to(config["train"]["device"])
        self.latent_decompressor = LatentDecompressor(
            config["model"]["d_model"]).to(config["train"]["device"])
        voc_size = config["tokens"]["vocab_size"]
        device = config["train"]["device"]
        self.encoder, self.decoder, self.generator = make_model(
            voc_size, voc_size, N=config["model"]["layers"], device=device)

        if config["train"]["aae"]:
            self.discriminator = Discriminator(
                config["model"]["d_model"],
                config["model"]["discriminator_dropout"]).to(
                    config["train"]["device"])

        # Create optimizers
        enc_params = list(self.encoder.parameters()) + list(
            self.latent_compressor.parameters())
        self.encoder_optimizer = CTOpt(
            torch.optim.Adam(enc_params, lr=0, betas=(0.9, 0.98)),
            config["train"]["warmup_steps"],
            (config["train"]["lr_min"], config["train"]["lr_max"]),
            config["train"]["decay_steps"], config["train"]["minimum_lr"])
        dec_params = list(self.latent_decompressor.parameters()) + list(
            self.decoder.parameters()) + list(self.generator.parameters())
        self.decoder_optimizer = CTOpt(
            torch.optim.Adam(dec_params, lr=0, betas=(0.9, 0.98)),
            config["train"]["warmup_steps"],
            (config["train"]["lr_min"], config["train"]["lr_max"]),
            config["train"]["decay_steps"], config["train"]["minimum_lr"])

        if config["train"]["aae"]:
            self.disc_optimizer = CTOpt(
                torch.optim.Adam([{
                    "params": self.discriminator.parameters()
                }],
                                 lr=0,
                                 betas=(0.9, 0.98)),
                config["train"]["warmup_steps"],
                (config["train"]["lr_min"], config["train"]["lr_max"]),
                config["train"]["decay_steps"], config["train"]["minimum_lr"])
            self.gen_optimizer = CTOpt(
                torch.optim.Adam(enc_params, lr=0, betas=(0.9, 0.98)),
                config["train"]["warmup_steps"],
                (config["train"]["lr_min"], config["train"]["lr_max"]),
                config["train"]["decay_steps"], config["train"]["minimum_lr"])
        self.criterion = LabelSmoothing(size=config["tokens"]["vocab_size"],
                                        padding_idx=0,
                                        smoothing=0.1).to(device)

        # Load dataset
        tr_loader = SongIterator(
            dataset_path=config["paths"]["dataset"] + os.sep + "train",
            batch_size=config["train"]["batch_size"],
            n_workers=config["train"]["n_workers"]).get_loader()
        ts_loader = SongIterator(
            dataset_path=config["paths"]["dataset"] + os.sep + "eval",
            batch_size=config["train"]["batch_size"],
            n_workers=config["train"]["n_workers"]).get_loader()

        # Init WANDB
        self.logger = Logger()
        wandb.login()
        wandb.init(project="MusAE",
                   config=config,
                   name="r_" + timestamp if remote else "l_" + timestamp)
        wandb.watch(self.encoder, log_freq=1000, log="all")
        wandb.watch(self.latent_compressor, log_freq=1000, log="all")
        wandb.watch(self.latent_decompressor, log_freq=1000, log="all")
        wandb.watch(self.decoder, log_freq=1000, log="all")
        wandb.watch(self.generator, log_freq=1000, log="all")
        if config["train"]["aae"]:
            wandb.watch(self.discriminator, log_freq=1000, log="all")

        # Print info about training
        time.sleep(
            1.)  # sleep for one second to let the machine connect to wandb
        if config["train"]["verbose"]:
            print("Giving", len(tr_loader), "training samples and",
                  len(ts_loader), "test samples")
            # print("Final set has size", len(dataset.final_set))
            print("Model has", config["model"]["layers"], "layers")
            print("Batch size is", config["train"]["batch_size"])
            print("d_model is", config["model"]["d_model"])
            if config["train"]["aae"]:
                print("Imposing prior distribution on latents")
                print("Starting training aae after",
                      config["train"]["train_aae_after_steps"])
                print("lambda:", config["train"]["lambda"],
                      ", critic iterations:",
                      config["train"]["critic_iterations"])
            else:
                print("NOT imposing prior distribution on latents")
            if config["train"]["log_images"]:
                print("Logging images")
            else:
                print("NOT logging images")
            if config["train"]["make_songs"]:
                print("Making songs every",
                      config["train"]["after_steps_make_songs"])
            else:
                print("NOT making songs")
            if config["train"]["do_eval"]:
                if config["train"]["eval_after_epoch"]:
                    print("Doing evaluation after each epoch")
                else:
                    print("Doing evaluation after",
                          config["train"]["after_steps_do_eval"])
            else:
                print("NOT DOING evaluation")
            if config["train"]["scheduled_sampling"]:
                print("Using scheduled sampling")
            else:
                print("NOT using scheduled sampling")
            if config["train"]["compress_latents"]:
                print("Compressing latents")
            else:
                print("NOT compressing latents")
            if config["train"]["use_rel_pos"]:
                print("Using relative positional encoding")
            else:
                print("NOT using relative positional encoding")
            print("Save model every",
                  config["train"]["after_steps_save_model"])
            if remote:
                wandb.save("compress_latents.py")
                wandb.save("train.py")
                wandb.save("config.py")
                wandb.save("test.py")
                wandb.save("loss_computer.py")
                wandb.save("utilities.py")
                wandb.save("discriminator.py")
                wandb.save("compressive_transformer.py")

        # Setup train
        self.encoder.train()
        self.latent_compressor.train()
        self.latent_decompressor.train()
        self.decoder.train()
        self.generator.train()
        if config["train"]["aae"]:
            self.discriminator.train()
        desc = "Train epoch " + str(self.epoch) + ", mb " + str(0)
        if config["train"]["eval_after_epoch"]:
            train_progress = tqdm(total=len(tr_loader),
                                  position=0,
                                  leave=True,
                                  desc=desc)
        else:
            train_progress = tqdm(total=config["train"]["after_steps_do_eval"],
                                  position=0,
                                  leave=True,
                                  desc=desc)
        self.step = 0  # -1 to do eval in first step
        first_batch = None

        # Main loop
        for self.epoch in range(config["train"]["n_epochs"]):  # for each epoch
            for song_it, batch in enumerate(tr_loader):  # for each song

                #########
                # TRAIN #
                #########
                if first_batch is None:  # if training reconstruct from train, if eval reconstruct from eval
                    first_batch = batch
                second_batch = batch
                tr_losses = self.run_mb(batch)

                if self.step % 10 == 0:
                    self.logger.log_losses(tr_losses, self.encoder.training)
                    self.logger.log_stuff(
                        self.encoder_optimizer.lr, self.latent,
                        self.disc_optimizer.lr if config["train"]["aae"] else
                        None, self.gen_optimizer.lr
                        if config["train"]["aae"] else None,
                        self.beta if config["train"]["aae"] else None,
                        get_prior(self.latent.shape)
                        if config["train"]["aae"] else None, self.tf_prob)
                if self.step == 0:
                    print("Latent shape is:", self.latent.shape)
                train_progress.update()

                ########
                # EVAL #
                ########
                eae = config["train"]["eval_after_epoch"]
                do_eval = config["train"]["do_eval"]
                sbe = config["train"]["after_steps_do_eval"]
                if ((eae and song_it == 0) or
                    (not eae
                     and self.step % sbe == 0)) and do_eval and self.step > 0:
                    print("Evaluation")
                    train_progress.close()
                    ts_losses = []

                    self.encoder.eval()
                    self.latent_compressor.eval()
                    self.latent_decompressor.eval()
                    self.decoder.eval()
                    self.generator.eval()

                    if config["train"]["aae"]:
                        self.discriminator.eval()
                    desc = "Eval epoch " + str(
                        self.epoch) + ", mb " + str(song_it)

                    # Compute validation score
                    first_batch = None
                    for test in tqdm(ts_loader,
                                     position=0,
                                     leave=True,
                                     desc=desc):  # remember test losses
                        if first_batch is None:
                            first_batch = test
                        second_batch = test
                        with torch.no_grad():
                            ts_loss = self.run_mb(test)
                        ts_losses.append(ts_loss)
                    final = ()  # average losses
                    for i in range(len(ts_losses[0])):  # for each loss value
                        aux = []
                        for loss in ts_losses:  # for each computed loss
                            aux.append(loss[i])
                        avg = sum(aux) / len(aux)
                        final = final + (avg, )
                    self.logger.log_losses(final, self.encoder.training)

                    # eval end
                    self.encoder.train()
                    self.latent_compressor.train()
                    self.latent_decompressor.train()
                    self.decoder.train()
                    self.generator.train()
                    if config["train"]["aae"]:
                        self.discriminator.train()
                    desc = "Train epoch " + str(
                        self.epoch) + ", mb " + str(song_it)
                    if config["train"]["eval_after_epoch"]:
                        train_progress = tqdm(total=len(tr_loader),
                                              position=0,
                                              leave=True,
                                              desc=desc)
                    else:
                        train_progress = tqdm(
                            total=config["train"]["after_steps_do_eval"],
                            position=0,
                            leave=True,
                            desc=desc)

                ##############
                # SAVE MODEL #
                ##############
                if (self.step % config["train"]["after_steps_save_model"]
                    ) == 0 and self.step > 0:
                    full_path = self.save_path + os.sep + str(self.step)
                    os.makedirs(full_path)
                    print("Saving last model in " + full_path +
                          ", DO NOT INTERRUPT")
                    torch.save(self.encoder,
                               os.path.join(full_path, "encoder.pt"),
                               pickle_module=dill)
                    torch.save(self.latent_compressor,
                               os.path.join(full_path, "latent_compressor.pt"),
                               pickle_module=dill)
                    torch.save(self.latent_decompressor,
                               os.path.join(full_path,
                                            "latent_decompressor.pt"),
                               pickle_module=dill)
                    torch.save(self.decoder,
                               os.path.join(full_path, "decoder.pt"),
                               pickle_module=dill)
                    torch.save(self.generator,
                               os.path.join(full_path, "generator.pt"),
                               pickle_module=dill)
                    if config["train"]["aae"]:
                        torch.save(self.discriminator,
                                   os.path.join(full_path, "discriminator.pt"),
                                   pickle_module=dill)
                    print("Model saved")

                ########
                # TEST #
                ########
                if (self.step % config["train"]["after_steps_make_songs"]) == 0 and config["train"]["make_songs"] \
                        and self.step > 0:
                    print("Making songs")
                    self.encoder.eval()
                    self.latent_compressor.eval()
                    self.latent_decompressor.eval()
                    self.decoder.eval()
                    self.generator.eval()

                    self.tester = Tester(self.encoder, self.latent_compressor,
                                         self.latent_decompressor,
                                         self.decoder, self.generator)

                    # RECONSTRUCTION
                    note_manager = NoteRepresentationManager()
                    to_reconstruct = second_batch
                    with torch.no_grad():
                        original, reconstructed, acc = self.tester.reconstruct(
                            to_reconstruct, note_manager)
                    prefix = "epoch_" + str(self.epoch) + "_mb_" + str(song_it)
                    self.logger.log_songs(os.path.join(wandb.run.dir, prefix),
                                          [original, reconstructed],
                                          ["original", "reconstructed"],
                                          "validation reconstruction example")
                    self.logger.log_reconstruction_accuracy(acc)

                    if config["train"]["aae"]:
                        # GENERATION
                        with torch.no_grad():
                            generated = self.tester.generate(
                                note_manager)  # generation
                        self.logger.log_songs(
                            os.path.join(wandb.run.dir, prefix), [generated],
                            ["generated"], "generated")

                        # INTERPOLATION
                        with torch.no_grad():
                            first, interpolation, second = self.tester.interpolation(
                                note_manager, first_batch, second_batch)

                        self.logger.log_songs(
                            os.path.join(wandb.run.dir, prefix),
                            [first, interpolation, second],
                            ["first", "interpolation", "second"],
                            "interpolation")
                    # end test
                    self.encoder.train()
                    self.latent_compressor.train()
                    self.latent_decompressor.train()
                    self.decoder.train()
                    self.generator.train()

                self.step += 1