Example #1
    def __init__(self, rb=9):
        self.gen_x = Generator(rb).to(device)
        self.gen_y = Generator(rb).to(device)
        self.dis_x = Discriminator().to(device)
        self.dis_y = Discriminator().to(device)
        self.fake_x_buffer = ImageBuffer()
        self.fake_y_buffer = ImageBuffer()
        self.crit = nn.MSELoss()
        self.l1 = torch.nn.L1Loss()
        self.optimizer_gen = torch.optim.Adam(list(self.gen_x.parameters()) +
                                              list(self.gen_y.parameters()),
                                              lr=lr,
                                              betas=betas)
        self.optimizer_dis = torch.optim.Adam(list(self.dis_x.parameters()) +
                                              list(self.dis_y.parameters()),
                                              lr=lr,
                                              betas=betas)
        self.scaler_gen = torch.cuda.amp.GradScaler()
        self.scaler_dis = torch.cuda.amp.GradScaler()
        self.lr_dis = None
        self.lr_gen = None

        self.gen_y.apply(self.init_weights)
        self.gen_x.apply(self.init_weights)
        self.dis_x.apply(self.init_weights)
        self.dis_y.apply(self.init_weights)
Example #2
    def __init__(self,
                 batch_size=64,
                 noise_vector_size=100,
                 num_epochs=1,
                 lr=0.0002,
                 beta1=0.5):
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")
        self.data_provider = Data_Provider(batch_size)
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.netG = Generator(noise_vector_size,
                              self.data_provider.num_ingredients).to(
                                  self.device)
        self.netD = Discriminator(self.data_provider.num_ingredients).to(
            self.device)

        self.criterion = nn.BCELoss()
        self.fixed_noise = torch.randn(batch_size,
                                       noise_vector_size,
                                       device=self.device)
        self.noise_vector_size = noise_vector_size
        self.real_label = 1
        self.fake_label = 0

        self.optimizerD = optim.Adam(self.netD.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))

        self.recipe_list = []
Example #3
    def __init__(self):
        self.netG = Generator().to(device)
        self.netD = Discriminator().to(device)
        self.netG.apply(self.weights_init)
        self.netD.apply(self.weights_init)
        self.fixed_noise = torch.randn(16, nz, 1, 1, device=device)
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas=betas)
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=betas)
Example #4
class CycleGan:
    def __init__(self, rb=9):
        self.gen_x = Generator(rb).to(device)
        self.gen_y = Generator(rb).to(device)
        self.dis_x = Discriminator().to(device)
        self.dis_y = Discriminator().to(device)
        self.fake_x_buffer = ImageBuffer()
        self.fake_y_buffer = ImageBuffer()
        self.crit = nn.MSELoss()
        self.l1 = torch.nn.L1Loss()
        self.optimizer_gen = torch.optim.Adam(list(self.gen_x.parameters()) +
                                              list(self.gen_y.parameters()),
                                              lr=lr,
                                              betas=betas)
        self.optimizer_dis = torch.optim.Adam(list(self.dis_x.parameters()) +
                                              list(self.dis_y.parameters()),
                                              lr=lr,
                                              betas=betas)
        self.scaler_gen = torch.cuda.amp.GradScaler()
        self.scaler_dis = torch.cuda.amp.GradScaler()
        self.lr_dis = None
        self.lr_gen = None

        self.gen_y.apply(self.init_weights)
        self.gen_x.apply(self.init_weights)
        self.dis_x.apply(self.init_weights)
        self.dis_y.apply(self.init_weights)

    # Initialize conv weights with a normal distribution
    def init_weights(self, m):
        if type(m) == torch.nn.Conv2d:
            torch.nn.init.normal_(m.weight, std=0.02, mean=0.0)

    # Toggle gradient tracking on discriminators
    def grad_toggle(self, grad):
        for param in self.dis_x.parameters():
            param.requires_grad = grad
        for param in self.dis_y.parameters():
            param.requires_grad = grad

    # Generator loss is gen_a vs 1's
    def loss_gen(self, result):
        return self.crit(result, torch.ones_like(result).to(device))

    # Discriminator loss: real vs 1's, fake vs 0's
    def loss_dis(self, real, fake):
        loss1 = self.crit(real, torch.ones_like(real).to(device))
        loss2 = self.crit(fake, torch.zeros_like(fake).to(device))
        return (loss1 + loss2) * .5

    # Cyclic loss is a vs gen_a(gen_b(a))
    def loss_cyclic(self, cycled, real):
        loss = self.l1(cycled, real) * lamb
        return loss

    # Identity loss a vs gen_a(a)
    def loss_identity(self, ident, real):
        loss = self.l1(ident, real) * lamb * ident_weight
        return loss

    # Return the generated and cycled images
    def test_photo(self, image, y_in=False):
        if not y_in:
            with torch.no_grad():
                fake_y = self.gen_y(image)
                cycled = self.gen_x(fake_y)
                return (fake_y, cycled)
        else:
            with torch.no_grad():
                fake_x = self.gen_x(image)
                cycled = self.gen_y(fake_x)
                return (fake_x, cycled)

    def step(self, x, y, step, total_step, log=False):
        self.optimizer_gen.zero_grad()
        self.grad_toggle(False)

        # Finding loss of the generators
        with torch.cuda.amp.autocast():
            # input -> target
            fake_y = self.gen_y(x)
            output_fake_y = self.dis_y(fake_y)
            loss_fake_y = self.loss_gen(output_fake_y)

            # target -> input
            fake_x = self.gen_x(y)
            output_fake_x = self.dis_x(fake_x)
            loss_fake_x = self.loss_gen(output_fake_x)

            # cycled
            cycled_y = self.gen_y(fake_x)
            loss_cycled_y = self.loss_cyclic(cycled_y, y)

            cycled_x = self.gen_x(fake_y)
            loss_cycled_x = self.loss_cyclic(cycled_x, x)

            # identities
            ident_x = self.gen_x(x)
            ident_y = self.gen_y(y)

            loss_ident_y = self.loss_identity(ident_y, y)
            loss_ident_x = self.loss_identity(ident_x, x)

            loss_g = (loss_fake_y + loss_cycled_x + loss_fake_x +
                      loss_cycled_y + loss_ident_y + loss_ident_x)

        self.scaler_gen.scale(loss_g).backward()
        self.scaler_gen.step(self.optimizer_gen)
        self.scaler_gen.update()
        self.grad_toggle(True)
        self.optimizer_dis.zero_grad()

        # Finding loss of the discriminator
        with torch.cuda.amp.autocast():
            temp = self.fake_y_buffer.grab(fake_y.detach())
            dis_fake_y = self.dis_y(temp.detach())
            dis_y = self.dis_y(y)
            loss_Dy = self.loss_dis(dis_y, dis_fake_y)

            temp = self.fake_x_buffer.grab(fake_x.detach())
            dis_fake_x = self.dis_x(temp.detach())
            dis_x = self.dis_x(x)
            loss_Dx = self.loss_dis(dis_x, dis_fake_x)
            loss = loss_Dx + loss_Dy

        self.scaler_dis.scale(loss).backward()
        self.scaler_dis.step(self.optimizer_dis)
        self.scaler_dis.update()

        if log:
            print(
                "Step:",
                str(step) + "/" + str(total_step),
            )
            print("loss fake_y: ", loss_fake_y.item(), "loss cycled_x:",
                  loss_cycled_x.item(), "loss ident_x:", loss_ident_x.item())
            print("loss fake_x: ", loss_fake_x.item(), "loss cycled_y:",
                  loss_cycled_y.item(), "loss ident_y:", loss_ident_y.item())
            print("loss Dx:", loss_Dx.item(), "loss Dy:", loss_Dy.item())

    # Checkpoint the training
    def checkpoint(self, epoch):
        path = checkpoint_dir + check_name
        torch.save(
            {
                'gen_x': self.gen_x.state_dict(),
                'gen_y': self.gen_y.state_dict(),
                'dis_x': self.dis_x.state_dict(),
                'dis_y': self.dis_y.state_dict(),
                'optimizer_gen': self.optimizer_gen.state_dict(),
                'optimizer_dis': self.optimizer_dis.state_dict(),
                'scaler_gen': self.scaler_gen.state_dict(),
                'scaler_dis': self.scaler_dis.state_dict(),
                'lr_gen': self.lr_gen,
                'lr_dis': self.lr_dis,
                'epoch': epoch,
            }, path)

    # Load latest checkpoint
    def loadcheck(self):
        path = checkpoint_dir + check_name
        check = torch.load(path)
        self.gen_x.load_state_dict(check['gen_x'])
        self.gen_y.load_state_dict(check['gen_y'])
        self.dis_x.load_state_dict(check['dis_x'])
        self.dis_y.load_state_dict(check['dis_y'])
        self.optimizer_gen.load_state_dict(check['optimizer_gen'])
        self.optimizer_dis.load_state_dict(check['optimizer_dis'])
        self.scaler_gen.load_state_dict(check['scaler_gen'])
        self.scaler_dis.load_state_dict(check['scaler_dis'])
        try:
            self.lr_gen = check['lr_gen']
            self.lr_dis = check['lr_dis']
        except KeyError:
            pass
        return check['epoch']

    # Linearly decay the learning rates starting at scale_epoch
    def scale_optimizers(self, current_epoch, total_epochs, scale_epoch):
        if current_epoch >= scale_epoch:
            if self.lr_dis is None:
                self.lr_dis = self.optimizer_dis.param_groups[0]["lr"]
                self.lr_gen = self.optimizer_gen.param_groups[0]["lr"]
            scale = 1 - (current_epoch - scale_epoch) / (total_epochs -
                                                         scale_epoch)
            self.optimizer_dis.param_groups[0]["lr"] = self.lr_dis * scale
            self.optimizer_gen.param_groups[0]["lr"] = self.lr_gen * scale
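Both CycleGan snippets rely on an ImageBuffer whose grab() method is never shown. A minimal sketch of such a buffer, assuming the 50-image history pool from the CycleGAN paper and that grab() returns a mix of fresh and previously stored fakes (the original class may differ):

import random
import torch


class ImageBuffer:
    def __init__(self, max_size=50):
        # pool of previously generated images fed to the discriminators
        self.max_size = max_size
        self.images = []

    def grab(self, batch):
        out = []
        for image in batch:
            image = image.unsqueeze(0)
            if len(self.images) < self.max_size:
                # fill the pool first and return the fresh image unchanged
                self.images.append(image)
                out.append(image)
            elif random.random() > 0.5:
                # half of the time, swap a stored image for the fresh one
                idx = random.randrange(self.max_size)
                out.append(self.images[idx].clone())
                self.images[idx] = image
            else:
                out.append(image)
        return torch.cat(out, dim=0)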
Example #5
    def __init__(self, args):
        super(Trainer, self).__init__()
        for k, v in vars(args).items():
            setattr(self, k, v)
        self.args = args
        self.data_path = './origin_data/' + self.dataset + '/'

        self.train_tasks = json.load(open(self.data_path + 'train_tasks.json'))
        self.rel2id = json.load(open(self.data_path + 'relation2ids'))

        # Generate the relation matrix according to word embeddings and TFIDF
        if self.generate_text_embedding:
            if self.dataset == "NELL":
                NELL_text_embedding(self.args)
            else:
                raise AttributeError("wrong dataset name!")

        rela_matrix = np.load(self.data_path + 'rela_matrix.npz')['relaM']
        print('##LOADING RELATION MATRIX##')
        self.rela_matrix = rela_matrix.astype('float32')

        self.ent2id = json.load(open(self.data_path + 'entity2id'))

        print('##LOADING CANDIDATES ENTITIES##')
        self.rel2candidates = json.load(
            open(self.data_path + 'rel2candidates_all.json'))

        # load answer dict
        self.e1rel_e2 = defaultdict(list)
        self.e1rel_e2 = json.load(open(self.data_path + 'e1rel_e2_all.json'))

        noises = Variable(torch.randn(self.test_sample, self.noise_dim)).cuda()
        self.test_noises = 0.1 * noises

        self.meta = not self.no_meta
        self.label_num = len(self.train_tasks.keys())

        self.rela2label = dict()
        rela_sorted = sorted(list(self.train_tasks.keys()))
        for i, rela in enumerate(rela_sorted):
            self.rela2label[rela] = int(i)

        print('##LOADING SYMBOL ID AND SYMBOL EMBEDDING')
        self.load_embed()
        self.num_symbols = len(self.symbol2id.keys()) - 1  # exclude PAD; pad_id equals num_symbols
        self.pad_id = self.num_symbols

        print('##DEFINE FEATURE EXTRACTOR')
        self.Extractor = Extractor(self.embed_dim,
                                   self.num_symbols,
                                   embed=self.symbol2vec)
        self.Extractor.cuda()
        self.Extractor.apply(weights_init)
        self.E_parameters = filter(lambda p: p.requires_grad,
                                   self.Extractor.parameters())
        self.optim_E = optim.Adam(self.E_parameters, lr=self.lr_E)
        #self.scheduler = optim.lr_scheduler.MultiStepLR(self.optim_E, milestones=[50000], gamma=0.5)

        print('##DEFINE GENERATOR')
        self.Generator = Generator(self.args)
        self.Generator.cuda()
        self.Generator.apply(weights_init)
        self.G_parameters = filter(lambda p: p.requires_grad,
                                   self.Generator.parameters())
        self.optim_G = optim.Adam(self.G_parameters,
                                  lr=self.lr_G,
                                  betas=(0.5, 0.9))
        self.scheduler_G = optim.lr_scheduler.MultiStepLR(self.optim_G,
                                                          milestones=[4000],
                                                          gamma=0.2)

        print('##DEFINE DISCRIMINATOR')
        self.Discriminator = Discriminator()
        self.Discriminator.cuda()
        self.Discriminator.apply(weights_init)
        self.D_parameters = filter(lambda p: p.requires_grad,
                                   self.Discriminator.parameters())
        self.optim_D = optim.Adam(self.D_parameters,
                                  lr=self.lr_D,
                                  betas=(0.5, 0.9))
        self.scheduler_D = optim.lr_scheduler.MultiStepLR(self.optim_D,
                                                          milestones=[20000],
                                                          gamma=0.2)

        self.num_ents = len(self.ent2id.keys())

        print('##BUILDING CONNECTION MATRIX')
        degrees = self.build_connection(max_=self.max_neighbor)
Example #6
class Trainer(object):
    def __init__(self, args):
        super(Trainer, self).__init__()
        for k, v in vars(args).items():
            setattr(self, k, v)
        self.args = args
        self.data_path = './origin_data/' + self.dataset + '/'

        self.train_tasks = json.load(open(self.data_path + 'train_tasks.json'))
        self.rel2id = json.load(open(self.data_path + 'relation2ids'))

        # Generate the relation matrix according to word embeddings and TFIDF
        if self.generate_text_embedding:
            if self.dataset == "NELL":
                NELL_text_embedding(self.args)
            else:
                raise AttributeError("wrong dataset name!")

        rela_matrix = np.load(self.data_path + 'rela_matrix.npz')['relaM']
        print('##LOADING RELATION MATRIX##')
        self.rela_matrix = rela_matrix.astype('float32')

        self.ent2id = json.load(open(self.data_path + 'entity2id'))

        print('##LOADING CANDIDATES ENTITIES##')
        self.rel2candidates = json.load(
            open(self.data_path + 'rel2candidates_all.json'))

        # load answer dict
        self.e1rel_e2 = defaultdict(list)
        self.e1rel_e2 = json.load(open(self.data_path + 'e1rel_e2_all.json'))

        noises = Variable(torch.randn(self.test_sample, self.noise_dim)).cuda()
        self.test_noises = 0.1 * noises

        self.meta = not self.no_meta
        self.label_num = len(self.train_tasks.keys())

        self.rela2label = dict()
        rela_sorted = sorted(list(self.train_tasks.keys()))
        for i, rela in enumerate(rela_sorted):
            self.rela2label[rela] = int(i)

        print('##LOADING SYMBOL ID AND SYMBOL EMBEDDING')
        self.load_embed()
        self.num_symbols = len(self.symbol2id.keys()) - 1  # exclude PAD; pad_id equals num_symbols
        self.pad_id = self.num_symbols

        print('##DEFINE FEATURE EXTRACTOR')
        self.Extractor = Extractor(self.embed_dim,
                                   self.num_symbols,
                                   embed=self.symbol2vec)
        self.Extractor.cuda()
        self.Extractor.apply(weights_init)
        self.E_parameters = filter(lambda p: p.requires_grad,
                                   self.Extractor.parameters())
        self.optim_E = optim.Adam(self.E_parameters, lr=self.lr_E)
        #self.scheduler = optim.lr_scheduler.MultiStepLR(self.optim_E, milestones=[50000], gamma=0.5)

        print('##DEFINE GENERATOR')
        self.Generator = Generator(self.args)
        self.Generator.cuda()
        self.Generator.apply(weights_init)
        self.G_parameters = filter(lambda p: p.requires_grad,
                                   self.Generator.parameters())
        self.optim_G = optim.Adam(self.G_parameters,
                                  lr=self.lr_G,
                                  betas=(0.5, 0.9))
        self.scheduler_G = optim.lr_scheduler.MultiStepLR(self.optim_G,
                                                          milestones=[4000],
                                                          gamma=0.2)

        print('##DEFINE DISCRIMINATOR')
        self.Discriminator = Discriminator()
        self.Discriminator.cuda()
        self.Discriminator.apply(weights_init)
        self.D_parameters = filter(lambda p: p.requires_grad,
                                   self.Discriminator.parameters())
        self.optim_D = optim.Adam(self.D_parameters,
                                  lr=self.lr_D,
                                  betas=(0.5, 0.9))
        self.scheduler_D = optim.lr_scheduler.MultiStepLR(self.optim_D,
                                                          milestones=[20000],
                                                          gamma=0.2)

        self.num_ents = len(self.ent2id.keys())

        print('##BUILDING CONNECTION MATRIX')
        degrees = self.build_connection(max_=self.max_neighbor)

    def load_symbol2id(self):
        symbol_id = {}
        i = 0
        for key in self.rel2id.keys():
            if key not in ['', 'OOV']:
                symbol_id[key] = i
                i += 1

        for key in self.ent2id.keys():
            if key not in ['', 'OOV']:
                symbol_id[key] = i
                i += 1
        symbol_id['PAD'] = i
        self.symbol2id = symbol_id
        self.symbol2vec = None

    def load_embed(self):

        symbol_id = {}

        print('##LOADING PRE-TRAINED EMBEDDING')
        if self.embed_model in ['DistMult', 'TransE', 'ComplEx', 'RESCAL']:
            embed_all = np.load(self.data_path + self.embed_model +
                                '_embed.npz')
            ent_embed = embed_all['eM']
            rel_embed = embed_all['rM']

            if self.embed_model == 'ComplEx':
                # normalize the complex embeddings
                ent_mean = np.mean(ent_embed, axis=1, keepdims=True)
                ent_std = np.std(ent_embed, axis=1, keepdims=True)
                rel_mean = np.mean(rel_embed, axis=1, keepdims=True)
                rel_std = np.std(rel_embed, axis=1, keepdims=True)
                eps = 1e-3
                ent_embed = (ent_embed - ent_mean) / (ent_std + eps)
                rel_embed = (rel_embed - rel_mean) / (rel_std + eps)

            print(
                '    ent_embed shape is {}, the number of entity is {}'.format(
                    ent_embed.shape, len(self.ent2id.keys())))
            print('    rel_embed shape is {}, the number of relation is {}'.
                  format(rel_embed.shape, len(self.rel2id.keys())))

            i = 0
            embeddings = []
            for key in self.rel2id.keys():
                if key not in ['', 'OOV']:
                    symbol_id[key] = i
                    i += 1
                    embeddings.append(list(rel_embed[self.rel2id[key], :]))

            for key in self.ent2id.keys():
                if key not in ['', 'OOV']:
                    symbol_id[key] = i
                    i += 1
                    embeddings.append(list(ent_embed[self.ent2id[key], :]))

            symbol_id['PAD'] = i
            embeddings.append(list(np.zeros((rel_embed.shape[1], ))))
            embeddings = np.array(embeddings)
            np.savez('origin_data/NELL/Embed_used/' + self.embed_model,
                     embeddings)
            json.dump(
                symbol_id,
                open('origin_data/NELL/Embed_used/' + self.embed_model + '2id',
                     'w'))

            self.symbol2id = symbol_id
            self.symbol2vec = embeddings

    #  build neighbor connection
    def build_connection(self, max_=100):

        self.connections = (np.ones(
            (self.num_ents, max_, 2)) * self.pad_id).astype(int)
        self.e1_rele2 = defaultdict(list)
        self.e1_degrees = defaultdict(int)
        with open(self.data_path + 'path_graph') as f:
            lines = f.readlines()
            for line in tqdm(lines):
                e1, rel, e2 = line.rstrip().split()
                self.e1_rele2[e1].append(
                    (self.symbol2id[rel], self.symbol2id[e2]))
                #self.e1_rele2[e2].append((self.symbol2id[rel+'_inv'], self.symbol2id[e1]))
                self.e1_rele2[e2].append(
                    (self.symbol2id[rel], self.symbol2id[e1]))

        degrees = {}
        for ent, id_ in self.ent2id.items():
            neighbors = self.e1_rele2[ent]
            if len(neighbors) > max_:
                neighbors = neighbors[:max_]
            # degrees.append(len(neighbors))
            degrees[ent] = len(neighbors)
            self.e1_degrees[id_] = len(neighbors)  # degree of this entity (capped at max_)
            for idx, _ in enumerate(neighbors):
                self.connections[id_, idx, 0] = _[0]
                self.connections[id_, idx, 1] = _[1]
        # json.dump(degrees, open(self.dataset + '/degrees', 'w'))
        # assert 1==2

        return degrees

    def save_pretrain(self, path=None):
        if not path:
            path = self.save_path
        torch.save(self.Extractor.state_dict(), path + 'Extractor')

    def load_pretrain(self):
        self.Extractor.load_state_dict(torch.load(self.save_path +
                                                  'Extractor'))

    def save(self, path=None):
        if not path:
            path = self.save_path
        torch.save(self.Generator.state_dict(), path + 'Generator')
        torch.save(self.Discriminator.state_dict(), path + 'Discriminator')

    def load(self):
        self.Generator.load_state_dict(torch.load(self.save_path +
                                                  'Generator'))
        self.Discriminator.load_state_dict(
            torch.load(self.save_path + 'Discriminator'))

    def get_meta(self, left, right):
        left_connections = Variable(
            torch.LongTensor(
                np.stack([self.connections[_, :, :] for _ in left],
                         axis=0))).cuda()
        left_degrees = Variable(
            torch.FloatTensor([self.e1_degrees[_] for _ in left])).cuda()
        right_connections = Variable(
            torch.LongTensor(
                np.stack([self.connections[_, :, :] for _ in right],
                         axis=0))).cuda()
        right_degrees = Variable(
            torch.FloatTensor([self.e1_degrees[_] for _ in right])).cuda()
        return (left_connections, left_degrees, right_connections,
                right_degrees)

    def pretrain_Extractor(self):
        print('\n##PRETRAINING FEATURE EXTRACTOR ....')
        pretrain_losses = deque([], 100)

        i = 0
        for data in Extractor_generate(self.data_path,
                                       self.pretrain_batch_size,
                                       self.symbol2id, self.ent2id,
                                       self.e1rel_e2, self.pretrain_few,
                                       self.pretrain_subepoch):
            i += 1

            support, query, false, support_left, support_right, query_left, query_right, false_left, false_right = data

            support_meta = self.get_meta(support_left, support_right)
            query_meta = self.get_meta(query_left, query_right)
            false_meta = self.get_meta(false_left, false_right)

            support = Variable(torch.LongTensor(support)).cuda()
            query = Variable(torch.LongTensor(query)).cuda()
            false = Variable(torch.LongTensor(false)).cuda()

            query_ep, query_scores = self.Extractor(query, support, query_meta,
                                                    support_meta)
            false_ep, false_scores = self.Extractor(false, support, false_meta,
                                                    support_meta)

            margin_ = query_scores - false_scores
            pretrain_loss = F.relu(self.pretrain_margin - margin_).mean()

            self.optim_E.zero_grad()
            pretrain_loss.backward()
            #self.scheduler.step()
            pretrain_losses.append(pretrain_loss.item())

            if i % self.pretrain_loss_every == 0:
                print("Step: %d, Feature Extractor Pretraining loss: %.2f" %
                      (i, np.mean(pretrain_losses)))

            self.optim_E.step()

            if i > self.pretrain_times:
                break

        self.save_pretrain()
        print('SAVE FEATURE EXTRACTOR PRETRAINING MODEL!!!')

    def train(self):
        print('\n##START ADVERSARIAL TRAINING...')

        # Pretraining step to obtain reasonable real data embeddings
        if self.pretrain_feature_extractor:
            self.pretrain_Extractor()
            print('Finish Pretraining!\n')

        self.load_pretrain()

        self.centroid_matrix = torch.zeros(
            (len(self.train_tasks), self.ep_dim))
        self.centroid_matrix = self.centroid_matrix.cuda()

        for relname in self.train_tasks.keys():
            query, query_left, query_right, label_id = centroid_generate(
                self.data_path, relname, self.symbol2id, self.ent2id,
                self.train_tasks, self.rela2label)
            query_meta = self.get_meta(query_left, query_right)
            query = Variable(torch.LongTensor(query)).cuda()
            query_ep, _ = self.Extractor(query, query, query_meta, query_meta)
            self.centroid_matrix[label_id] = query_ep.data.mean(dim=0)
        self.centroid_matrix = Variable(self.centroid_matrix)

        best_hits10 = 0.0

        D_every = self.D_epoch * self.loss_every
        D_losses = deque([], D_every)
        D_real_losses, D_real_class_losses, D_fake_losses, D_fake_class_losses \
            = deque([], D_every), deque([], D_every), deque([], D_every), deque([], D_every)

        # loss_G_fake + loss_G_class + loss_VP
        G_every = self.G_epoch * self.loss_every
        G_losses = deque([], G_every)
        G_fake_losses, G_class_losses, G_VP_losses, G_real_class_losses \
            = deque([], G_every), deque([], G_every), deque([], G_every), deque([], G_every)

        G_data = train_generate_decription(self.data_path, self.G_batch_size,
                                           self.symbol2id, self.ent2id,
                                           self.e1rel_e2, self.rel2id,
                                           self.args, self.rela2label,
                                           self.rela_matrix)

        nets = [self.Generator, self.Discriminator]
        reset_grad(nets)

        for epoch in range(self.train_times):

            # train Discriminator
            self.Discriminator.train()
            self.Generator.eval()
            for _ in range(self.D_epoch):  # D_epoch = 5
                ### Discriminator real part
                D_descriptions, query, query_left, query_right, D_false, D_false_left, D_false_right, D_labels = next(G_data)

                # real part
                query_meta = self.get_meta(query_left, query_right)
                query = Variable(torch.LongTensor(query)).cuda()
                D_real, _ = self.Extractor(query, query, query_meta,
                                           query_meta)

                # fake part
                noises = Variable(torch.randn(len(query),
                                              self.noise_dim)).cuda()
                D_descriptions = Variable(
                    torch.FloatTensor(D_descriptions)).cuda()
                D_fake = self.Generator(D_descriptions, noises)

                # neg part
                D_false_meta = self.get_meta(D_false_left, D_false_right)
                D_false = Variable(torch.LongTensor(D_false)).cuda()
                D_neg, _ = self.Extractor(D_false, D_false, D_false_meta,
                                          D_false_meta)

                # generate Discriminator part vector
                centroid_matrix_ = self.centroid_matrix  #gaussian_noise(self.centroid_matrix)
                _, D_real_decision, D_real_class = self.Discriminator(
                    D_real.detach(), centroid_matrix_)
                _, D_fake_decision, D_fake_class = self.Discriminator(
                    D_fake.detach(), centroid_matrix_)
                _, _, D_neg_class = self.Discriminator(D_neg.detach(),
                                                       self.centroid_matrix)

                # real adversarial training loss
                loss_D_real = -torch.mean(D_real_decision)

                # adversarial training loss
                loss_D_fake = torch.mean(D_fake_decision)

                # real classification loss
                D_real_scores = D_real_class[range(len(query)), D_labels]
                D_neg_scores = D_neg_class[range(len(query)), D_labels]
                D_margin_real = D_real_scores - D_neg_scores
                loss_rela_class = F.relu(self.pretrain_margin -
                                         D_margin_real).mean()

                # fake classification loss
                D_fake_scores = D_fake_class[range(len(query)), D_labels]
                D_margin_fake = D_fake_scores - D_neg_scores
                loss_fake_class = F.relu(self.pretrain_margin -
                                         D_margin_fake).mean()

                grad_penalty = calc_gradient_penalty(self.Discriminator,
                                                     D_real.data, D_fake.data,
                                                     len(query),
                                                     self.centroid_matrix)

                loss_D = loss_D_real + 0.5 * loss_rela_class + loss_D_fake + grad_penalty + 0.5 * loss_fake_class

                # D_real_losses, D_real_class_losses, D_fake_losses, D_fake_class_losses
                D_losses.append(loss_D.item())
                D_real_losses.append(loss_D_real.item())
                D_real_class_losses.append(loss_rela_class.item())
                D_fake_losses.append(loss_D_fake.item())
                D_fake_class_losses.append(loss_fake_class.item())

                loss_D.backward()
                self.optim_D.step()
                self.scheduler_D.step()
                reset_grad(nets)

            # train Generator
            self.Discriminator.eval()
            self.Generator.train()
            for _ in range(self.G_epoch):  # G_epoch = 1

                G_descriptions, query, query_left, query_right, G_false, G_false_left, G_false_right, G_labels = next(G_data)

                # G sample
                noises = Variable(torch.randn(len(query),
                                              self.noise_dim)).cuda()
                G_descriptions = Variable(
                    torch.FloatTensor(G_descriptions)).cuda()
                G_sample = self.Generator(G_descriptions, noises)  # to train G

                # real data
                query_meta = self.get_meta(query_left, query_right)
                query = Variable(torch.LongTensor(query)).cuda()
                G_real, _ = self.Extractor(query, query, query_meta,
                                           query_meta)

                # This negative for classification loss
                G_false_meta = self.get_meta(G_false_left, G_false_right)
                G_false = Variable(torch.LongTensor(G_false)).cuda()
                G_neg, _ = self.Extractor(
                    G_false, G_false, G_false_meta,
                    G_false_meta)  # just use Extractor to generate ep vector

                # generate Discriminator part vector
                centroid_matrix_ = self.centroid_matrix
                _, G_decision, G_class = self.Discriminator(
                    G_sample, centroid_matrix_)
                _, _, G_real_class = self.Discriminator(
                    G_real.detach(), centroid_matrix_)
                _, _, G_neg_class = self.Discriminator(G_neg.detach(),
                                                       centroid_matrix_)

                # adversarial training loss
                loss_G_fake = -torch.mean(G_decision)

                # G sample classification loss
                G_scores = G_class[range(len(query)), G_labels]
                G_neg_scores = G_neg_class[range(len(query)), G_labels]
                G_margin_ = G_scores - G_neg_scores
                loss_G_class = F.relu(self.pretrain_margin - G_margin_).mean()

                # real classification loss
                G_real_scores = G_real_class[range(len(query)), G_labels]
                G_margin_real = G_real_scores - G_neg_scores
                loss_rela_class_ = F.relu(self.pretrain_margin -
                                          G_margin_real).mean()

                # Visual Pivot Regularization
                count = 0
                loss_VP = Variable(torch.Tensor([0.0])).cuda()
                for i in range(len(self.train_tasks.keys())):
                    sample_idx = (np.array(G_labels) == i).nonzero()[0]
                    count += len(sample_idx)
                    if len(sample_idx) == 0:
                        loss_VP += 0.0
                    else:
                        G_sample_cls = G_sample[sample_idx, :]
                        loss_VP += (
                            G_sample_cls.mean(dim=0) -
                            self.centroid_matrix[i]).pow(2).sum().sqrt()
                assert count == len(query)
                loss_VP *= float(1.0 / self.gan_batch_rela)

                # ||W||_2 regularization
                reg_loss = Variable(torch.Tensor([0.0])).cuda()
                if self.REG_W != 0:
                    for name, p in self.Generator.named_parameters():
                        if 'weight' in name:
                            reg_loss += p.pow(2).sum()
                    reg_loss.mul_(self.REG_W)

                # ||W_z||21 regularization, make W_z sparse
                reg_Wz_loss = Variable(torch.Tensor([0.0])).cuda()
                if self.REG_Wz != 0:
                    Wz = self.Generator.fc1.weight
                    reg_Wz_loss = Wz.pow(2).sum(dim=0).sqrt().sum().mul(
                        self.REG_Wz)

                # Generator loss function
                loss_G = loss_G_fake + loss_G_class + 3.0 * loss_VP  # + reg_Wz_loss + reg_loss

                # G_fake_losses, G_class_losses, G_VP_losses
                G_losses.append(loss_G.item())
                G_fake_losses.append(loss_G_fake.item())
                G_class_losses.append(loss_G_class.item())
                G_real_class_losses.append(loss_rela_class_.item())
                G_VP_losses.append(loss_VP.item())

                loss_G.backward()
                self.optim_G.step()
                self.scheduler_G.step()
                reset_grad(nets)

            if epoch % self.loss_every == 0 and epoch != 0:
                D_screen = [
                    np.mean(D_real_losses),
                    np.mean(D_real_class_losses),
                    np.mean(D_fake_losses),
                    np.mean(D_fake_class_losses)
                ]
                G_screen = [
                    np.mean(G_fake_losses),
                    np.mean(G_class_losses),
                    np.mean(G_real_class_losses),
                    np.mean(G_VP_losses)
                ]
                print("Epoch: %d, D_loss: %.2f [%.2f, %.2f, %.2f, %.2f], G_loss: %.2f [%.2f, %.2f, %.2f, %.2f]" \
                    % (epoch, np.mean(D_losses), D_screen[0], D_screen[1], D_screen[2], D_screen[3], np.mean(G_losses), G_screen[0], G_screen[1], G_screen[2], G_screen[3]))
                #print("    D_lr", self.scheduler_D.get_lr()[0], "G_lr", self.scheduler_G.get_lr()[0])

        hits10_test, hits5_test, mrr_test = self.eval(mode='test',
                                                      meta=self.meta)
        self.save()

    def eval_pretrain(self, mode='test', meta=True):
        self.Extractor.eval()

        symbol2id = self.symbol2id

        test_candidates = json.load(
            open("./origin_data/NELL/" + mode + "_candidates.json"))

        print('##EVALUATING ON %s DATA' % mode.upper())
        if mode == 'dev':
            test_tasks = json.load(open(self.data_path + 'dev_tasks.json'))
        elif mode == 'test':
            test_tasks = json.load(open(self.data_path + 'test_tasks.json'))
        elif mode == 'train':
            test_tasks = json.load(open(self.data_path + 'train_tasks.json'))
        else:
            raise AttributeError("Wrong dataset type!")

        rela2label = dict()
        rela_sorted = sorted(list(test_tasks.keys()))
        for i, rela in enumerate(rela_sorted):
            rela2label[rela] = int(i)

        centroid_matrix = torch.zeros((len(test_tasks), self.ep_dim))
        centroid_matrix = centroid_matrix.cuda()
        for relname in test_tasks.keys():
            query, query_left, query_right, label_id = centroid_generate(
                self.data_path, relname, self.symbol2id, self.ent2id,
                test_tasks, rela2label)
            query_meta = self.get_meta(query_left, query_right)
            query = Variable(torch.LongTensor(query)).cuda()
            query_ep, _ = self.Extractor(query, query, query_meta, query_meta)
            centroid_matrix[label_id] = query_ep.data.mean(dim=0)
        centroid_matrix = Variable(centroid_matrix)

        hits10 = []
        hits5 = []
        hits1 = []
        mrr = []

        for query_ in test_candidates.keys():

            relation_vec = centroid_matrix[rela2label[query_]]
            relation_vec = relation_vec.view(1, -1)
            relation_vec = relation_vec.data.cpu().numpy()

            hits10_ = []
            hits5_ = []
            hits1_ = []
            mrr_ = []

            for e1_rel, tail_candidates in test_candidates[query_].items():
                head, rela, _ = e1_rel.split('\t')

                true = tail_candidates[0]  # the first candidate is the true tail entity
                query_pairs = []
                query_pairs.append([symbol2id[head], symbol2id[true]])

                if meta:
                    query_left = []
                    query_right = []
                    query_left.append(self.ent2id[head])
                    query_right.append(self.ent2id[true])

                for tail in tail_candidates[1:]:
                    query_pairs.append([symbol2id[head], symbol2id[tail]])
                    if meta:
                        query_left.append(self.ent2id[head])
                        query_right.append(self.ent2id[tail])

                query = Variable(torch.LongTensor(query_pairs)).cuda()

                if meta:
                    query_meta = self.get_meta(query_left, query_right)
                    candidate_vecs, _ = self.Extractor(query, query,
                                                       query_meta, query_meta)
                    candidate_vecs = candidate_vecs.detach()
                    candidate_vecs = candidate_vecs.data.cpu().numpy()

                    # dot product
                    #scores = candidate_vecs.dot(relation_vec.transpose())
                    #scores = scores.mean(axis=1)

                    # cosine similarity
                    scores = cosine_similarity(candidate_vecs, relation_vec)
                    scores = scores.mean(axis=1)

                    # Euclidean distance
                    #scores = np.power(candidate_vecs - relation_vec, 2)
                    #scores = scores.sum(axis=1)
                    #scores = np.sqrt(scores)
                    #scores = -scores

                    #print scores.shape

                    assert scores.shape == (len(query_pairs), )
                sort = list(np.argsort(scores))[::-1]
                rank = sort.index(0) + 1
                if rank <= 10:
                    hits10.append(1.0)
                    hits10_.append(1.0)
                else:
                    hits10.append(0.0)
                    hits10_.append(0.0)
                if rank <= 5:
                    hits5.append(1.0)
                    hits5_.append(1.0)
                else:
                    hits5.append(0.0)
                    hits5_.append(0.0)
                if rank <= 1:
                    hits1.append(1.0)
                    hits1_.append(1.0)
                else:
                    hits1.append(0.0)
                    hits1_.append(0.0)
                mrr.append(1.0 / rank)
                mrr_.append(1.0 / rank)

            print('{} Hits10:{:.3f}, Hits5:{:.3f}, Hits1:{:.3f} MRR:{:.3f}'.
                  format(mode + query_, np.mean(hits10_), np.mean(hits5_),
                         np.mean(hits1_), np.mean(mrr_)))
            #print('Number of candidates: {}, number of text examples {}'.format(len(candidates), len(hits10_)))

        print('############   ' + mode + '    #############')
        print('HITS10: {:.3f}'.format(np.mean(hits10)))
        print('HITS5: {:.3f}'.format(np.mean(hits5)))
        print('HITS1: {:.3f}'.format(np.mean(hits1)))
        print('MRR: {:.3f}'.format(np.mean(mrr)))
        print('###################################')

        return np.mean(hits10), np.mean(hits5), np.mean(mrr)

    def eval(self, mode='dev', meta=True):
        self.Generator.eval()
        self.Discriminator.eval()
        self.Extractor.eval()
        symbol2id = self.symbol2id

        #logging.info('EVALUATING ON %s DATA' % mode.upper())
        print('##EVALUATING ON %s DATA' % mode.upper())
        if mode == 'dev':
            test_tasks = json.load(open(self.data_path + 'dev_tasks.json'))
        elif mode == 'test':
            test_tasks = json.load(open(self.data_path + 'test_tasks.json'))
        elif mode == 'train':
            test_tasks = json.load(open(self.data_path + 'train_tasks.json'))
        else:
            raise AttributeError("Wrong dataset type!")

        test_candidates = json.load(
            open("./origin_data/NELL/" + mode + "_candidates.json"))

        hits10 = []
        hits5 = []
        hits1 = []
        mrr = []

        for query_ in test_candidates.keys():

            hits10_ = []
            hits5_ = []
            hits1_ = []
            mrr_ = []

            description = self.rela_matrix[self.rel2id[query_]]
            description = np.expand_dims(description, axis=0)
            descriptions = np.tile(description, (self.test_sample, 1))
            descriptions = Variable(torch.FloatTensor(descriptions)).cuda()
            relation_vecs = self.Generator(descriptions, self.test_noises)
            relation_vecs = relation_vecs.data.cpu().numpy()

            for e1_rel, tail_candidates in test_candidates[query_].items():
                head, rela, _ = e1_rel.split('\t')

                true = tail_candidates[0]
                query_pairs = []
                query_pairs.append([symbol2id[head], symbol2id[true]])

                if meta:
                    query_left = []
                    query_right = []
                    query_left.append(self.ent2id[head])
                    query_right.append(self.ent2id[true])

                for tail in tail_candidates[1:]:
                    query_pairs.append([symbol2id[head], symbol2id[tail]])
                    if meta:
                        query_left.append(self.ent2id[head])
                        query_right.append(self.ent2id[tail])

                query = Variable(torch.LongTensor(query_pairs)).cuda()

                if meta:
                    query_meta = self.get_meta(query_left, query_right)
                    candidate_vecs, _ = self.Extractor(query, query,
                                                       query_meta, query_meta)

                    candidate_vecs = candidate_vecs.detach()
                    candidate_vecs = candidate_vecs.data.cpu().numpy()

                    # dot product
                    #scores = candidate_vecs.dot(relation_vecs.transpose())

                    # cosine similarity
                    scores = cosine_similarity(candidate_vecs, relation_vecs)

                    scores = scores.mean(axis=1)

                    assert scores.shape == (len(query_pairs), )

                sort = list(np.argsort(scores))[::-1]
                rank = sort.index(0) + 1
                if rank <= 10:
                    hits10.append(1.0)
                    hits10_.append(1.0)
                else:
                    hits10.append(0.0)
                    hits10_.append(0.0)
                if rank <= 5:
                    hits5.append(1.0)
                    hits5_.append(1.0)
                else:
                    hits5.append(0.0)
                    hits5_.append(0.0)
                if rank <= 1:
                    hits1.append(1.0)
                    hits1_.append(1.0)
                else:
                    hits1.append(0.0)
                    hits1_.append(0.0)
                mrr.append(1.0 / rank)
                mrr_.append(1.0 / rank)

            #logging.critical('{} Hits10:{:.3f}, Hits5:{:.3f}, Hits1:{:.3f} MRR:{:.3f}'.format(query_, np.mean(hits10_), np.mean(hits5_), np.mean(hits1_), np.mean(mrr_)))
            #logging.info('Number of candidates: {}, number of text examples {}'.format(len(candidates), len(hits10_)))
            print('{} Hits10:{:.3f}, Hits5:{:.3f}, Hits1:{:.3f} MRR:{:.3f}'.
                  format(mode + query_, np.mean(hits10_), np.mean(hits5_),
                         np.mean(hits1_), np.mean(mrr_)))
            #print('Number of candidates: {}, number of text examples {}'.format(len(candidates), len(hits10_)))

        print('############   ' + mode + '    #############')
        print('HITS10: {:.3f}'.format(np.mean(hits10)))
        print('HITS5: {:.3f}'.format(np.mean(hits5)))
        print('HITS1: {:.3f}'.format(np.mean(hits1)))
        print('MRR: {:.3f}'.format(np.mean(mrr)))
        print('###################################')

        return np.mean(hits10), np.mean(hits5), np.mean(mrr)

    def test_(self):
        self.load()
        #logging.info('Pre-trained model loaded')
        print('Pre-trained model loaded')
        #self.eval(mode='dev', meta=self.meta)
        self.eval(mode='test', meta=self.meta)
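The Trainer above calls two module-level helpers, weights_init and reset_grad, that are defined elsewhere in the original repository. A plausible sketch under common GAN conventions (the real helpers may differ):

import torch.nn as nn


def weights_init(m):
    # assumption: normal init for linear/conv layers, zero bias
    if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0.0)


def reset_grad(nets):
    # clear accumulated gradients on every network in the list
    for net in nets:
        net.zero_grad()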
Example #7
    **kwargs)

train_loader_B = mnist_triplet_train_loader
test_loader_B = mnist_triplet_test_loader
train_loader_S = mnist_mini_triplet_train_loader
test_loader_S = mnist_mini_triplet_test_loader

margin = 1.
embedding_net_B = MLP_Embedding()  # define network for big datasets
triplet_net_B = TripletNet(embedding_net_B)
embedding_net_S = MLP_Embedding()  # define network for small datasets
triplet_net_S = TripletNet(embedding_net_S)

layer_size = (256, 16)
G = Generator(layer_size)
D = Discriminator(layer_size)
# define hooks
# h_B = embedding_net_B.fc2.register_backward_hook(hook_B)
# h_S = embedding_net_S.fc2.register_backward_hook(hook_S)
if cuda:
    triplet_net_S.cuda()
    triplet_net_B.cuda()
    G.cuda()
    D.cuda()

loss_fn_S = TripletLoss(margin)
loss_fn_B = TripletLoss(margin)
lr = 1e-3
optim_B = optim.Adam(triplet_net_B.parameters(), lr=lr)
optim_S = optim.Adam(triplet_net_S.parameters(), lr=lr)
optim_G = optim.Adam(G.parameters(), lr=0.001)
Example #8
from Utils import Loader, Trainer, Plotter, Logger
from Networks import Discriminator, Generator
import torch
from torch import nn
from torch.autograd import Variable

# Load data -- hardcoded
data_path = "NEED PATH"
data_name = "NEED NAME"
batch_size = 250
loader = Loader(data_path, batch_size)
data_loader = loader.get_loader(loader.get_dataset())
num_batches = len(data_loader)

# Create network instances
discriminator = Discriminator()
discriminator.apply(Trainer.init_weights)
generator = Generator()
generator.apply(Trainer.init_weights)

# Select device and move networks to the GPU if available
device = "cpu"
if torch.cuda.is_available():
    print("cuda available")
    device = "cuda:0"
    discriminator.cuda()
    generator.cuda()

# Optimizers and loss function
net_trainer = Trainer(nn.BCELoss(), device)
net_trainer.create_optimizers(discriminator.parameters(), generator.parameters())
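Example #8 depends on two Trainer members that are not shown: a static init_weights passed to Module.apply and create_optimizers. A hypothetical sketch of what they might look like; the method names come from the snippet, but the layer types and hyperparameters below are assumptions, not the Utils.Trainer API:

import torch.nn as nn
from torch import optim


class Trainer:
    def __init__(self, loss_fn, device):
        self.loss_fn = loss_fn
        self.device = device

    @staticmethod
    def init_weights(m):
        # assumption: DCGAN-style init for conv and batch-norm layers
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    def create_optimizers(self, d_params, g_params, lr=0.0002, betas=(0.5, 0.999)):
        # assumption: one Adam optimizer per network, kept on the trainer
        self.d_optimizer = optim.Adam(d_params, lr=lr, betas=betas)
        self.g_optimizer = optim.Adam(g_params, lr=lr, betas=betas)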
Example #9
    transforms.Normalize((0.5, ), (0.5, )),
])

to_pil_image = transforms.ToPILImage()

train_data = datasets.MNIST(root='../input/data',
                            train=True,
                            download=False,
                            transform=transform)
train_loader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=True)

from Networks import Generator
from Networks import Discriminator

generator = Generator(opt.latent_dim, opt.img_size, opt.channels).to(device)
discriminator = Discriminator(opt.img_size, opt.channels).to(device)

print(generator, discriminator)


def weights_init_normal(m):

    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


#Initialize Weights
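Example #9 is cut off right after the "#Initialize Weights" comment. A plausible continuation, stated only as an assumption, applies the init function to both networks with Module.apply:

# hypothetical continuation of the snippet above
generator.apply(weights_init_normal)
discriminator.apply(weights_init_normal)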
Example #10
to_pil_image = transforms.ToPILImage()

train_data = datasets.MNIST(
    root='../input/data',
    train=True,
    download=False,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

from Networks import Generator
from Networks import Discriminator

generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)

print(discriminator)
print(generator)

# optimizers
optim_g = optim.Adam(generator.parameters(), lr=0.0002)
optim_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# loss function
criterion = nn.BCELoss()

losses_g = []  # to store generator loss after each epoch
losses_d = []  # to store discriminator loss after each epoch
images = []  # to store images generated by the generator
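Example #10 stops just before its training loop. A minimal sketch of one vanilla-GAN step using the objects the snippet defines (generator, discriminator, optim_g, optim_d, criterion, nz, device); the label shapes assume the discriminator returns a (batch, 1) probability, which is an assumption about Networks.Discriminator:

import torch


def train_step(real_images):
    # real_images: one batch from train_loader
    b_size = real_images.size(0)
    real_images = real_images.to(device)
    real_labels = torch.ones(b_size, 1, device=device)
    fake_labels = torch.zeros(b_size, 1, device=device)

    # discriminator update: real batch vs. generated batch
    optim_d.zero_grad()
    noise = torch.randn(b_size, nz, device=device)
    fake_images = generator(noise)
    loss_d = (criterion(discriminator(real_images), real_labels) +
              criterion(discriminator(fake_images.detach()), fake_labels))
    loss_d.backward()
    optim_d.step()

    # generator update: try to fool the discriminator
    optim_g.zero_grad()
    loss_g = criterion(discriminator(fake_images), real_labels)
    loss_g.backward()
    optim_g.step()
    return loss_g.item(), loss_d.item()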
Example #11
class Module:
    def __init__(self,
                 batch_size=64,
                 noise_vector_size=100,
                 num_epochs=1,
                 lr=0.0002,
                 beta1=0.5):
        self.device = torch.device("cuda:0" if (
            torch.cuda.is_available()) else "cpu")
        self.data_provider = Data_Provider(batch_size)
        self.num_epochs = num_epochs
        self.batch_size = batch_size
        self.netG = Generator(noise_vector_size,
                              self.data_provider.num_ingredients).to(
                                  self.device)
        self.netD = Discriminator(self.data_provider.num_ingredients).to(
            self.device)

        self.criterion = nn.BCELoss()
        self.fixed_noise = torch.randn(batch_size,
                                       noise_vector_size,
                                       device=self.device)
        self.noise_vector_size = noise_vector_size
        self.real_label = 1
        self.fake_label = 0

        self.optimizerD = optim.Adam(self.netD.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))
        self.optimizerG = optim.Adam(self.netG.parameters(),
                                     lr=lr,
                                     betas=(beta1, 0.999))

        self.recipe_list = []

    def generate_recipe(self):
        noise = torch.randn(1, self.noise_vector_size, device=self.device)
        fake = self.netG(noise).detach().cpu().numpy().flatten()
        return self.data_provider.ingredients[[
            idx for idx, v in enumerate(fake) if v
        ]]

    def train_module(self):
        G_losses = []
        D_losses = []
        iters = 0

        print("Starting Training Loop...")
        for epoch in range(self.num_epochs):
            for i, data in enumerate(self.data_provider.dataloader, 0):
                self.netD.zero_grad()
                real_cpu = data.to(self.device).float()
                output = self.netD(real_cpu).view(-1)
                label = torch.full((real_cpu.size(0), ),
                                   self.real_label,
                                   dtype=torch.float,
                                   device=self.device)
                errD_real = self.criterion(output, label)
                errD_real.backward()
                D_x = output.mean().item()

                noise = torch.randn(real_cpu.size(0),
                                    self.noise_vector_size,
                                    device=self.device)
                fake = self.netG(noise)
                label.fill_(self.fake_label)
                output = self.netD(fake.detach()).view(-1)
                errD_fake = self.criterion(output, label)
                errD_fake.backward()
                D_G_z1 = output.mean().item()
                errD = errD_real + errD_fake
                self.optimizerD.step()

                self.netG.zero_grad()
                label.fill_(self.real_label)
                output = self.netD(fake).view(-1)
                errG = self.criterion(output, label)
                errG.backward()
                D_G_z2 = output.mean().item()
                self.optimizerG.step()

                if i % 50 == 0:
                    print(
                        '[%d/%d][%d/%d]\tLoss_D: %.4f\tLoss_G: %.4f\tD(x): %.4f\tD(G(z)): %.4f / %.4f'
                        % (epoch, self.num_epochs, i,
                           len(self.data_provider.dataloader), errD.item(),
                           errG.item(), D_x, D_G_z1, D_G_z2))

                G_losses.append(errG.item())
                D_losses.append(errD.item())

                if (iters % 500 == 0) or (
                    (epoch == self.num_epochs - 1) and
                    (i == len(self.data_provider.dataloader) - 1)):
                    with torch.no_grad():
                        fake = self.netG(self.fixed_noise).detach().cpu()
                    self.recipe_list.append(
                        vutils.make_grid(fake, padding=2, normalize=True))

                iters += 1

        self.netG.save()
        self.netD.save()
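The Module class above builds on a Data_Provider that is not included in the snippet. A hypothetical stand-in that only mirrors the attributes Module actually touches (num_ingredients, ingredients, dataloader); the dummy multi-hot recipe matrix is an assumption for illustration:

import numpy as np
import torch
from torch.utils.data import DataLoader


class Data_Provider:
    def __init__(self, batch_size, recipe_matrix=None, ingredients=None):
        if recipe_matrix is None:
            # dummy (num_recipes, num_ingredients) multi-hot matrix
            recipe_matrix = torch.randint(0, 2, (1000, 50)).float()
        if ingredients is None:
            ingredients = np.array(
                ["ingredient_%d" % i for i in range(recipe_matrix.shape[1])])
        self.ingredients = ingredients
        self.num_ingredients = recipe_matrix.shape[1]
        # a plain tensor acts as a map-style dataset, so batches come back as tensors
        self.dataloader = DataLoader(recipe_matrix,
                                     batch_size=batch_size,
                                     shuffle=True,
                                     drop_last=True)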
Example #12
class WGan():
    def __init__(self):
        self.netG = Generator().to(device)
        self.netD = Discriminator().to(device)
        self.netG.apply(self.weights_init)
        self.netD.apply(self.weights_init)
        self.fixed_noise = torch.randn(16, nz, 1, 1, device=device)
        self.optimizerD = optim.Adam(self.netD.parameters(), lr=lr, betas=betas)
        self.optimizerG = optim.Adam(self.netG.parameters(), lr=lr, betas=betas)


    def weights_init(self, m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    def checkpoint(self, epoch):
        path = checkpoint_dir + check_name
        torch.save({
            'netG': self.netG.state_dict(),
            'netD': self.netD.state_dict(),
            'optimizerD': self.optimizerD.state_dict(),
            'optimizerG': self.optimizerG.state_dict(),
            'fixed_noise': self.fixed_noise,
            # 'scaler_gen': self.scaler_gen.state_dict(),
            # 'scaler_dis': self.scaler_dis.state_dict(),
            'epoch': epoch,
            }, path)

    # Load latest checkpoint
    def loadcheck(self):
        path = checkpoint_dir + check_name
        check = torch.load(path)
        self.netG.load_state_dict(check['netG'])
        self.netD.load_state_dict(check['netD'])
        self.optimizerD.load_state_dict(check['optimizerD'])     
        self.optimizerG.load_state_dict(check['optimizerG'])
        self.fixed_noise = check['fixed_noise']
        # self.scaler_gen.load_state_dict(check['scaler_gen'])          
        # self.scaler_dis.load_state_dict(check['scaler_dis'])
        return check['epoch']    
    
    # Calculate gradient penalty described in the paper
    def gradient_penalty(self):
        # sample alpha uniformly in [0, 1] as in the WGAN-GP paper
        alpha = torch.rand(batch_size, 1, 1, 1, device=device)
        interp = alpha * self.real_batch + (1 - alpha) * self.fake_batch.detach()
        interp.requires_grad_()

        model_interp = self.netD(interp)
        grads = torch.autograd.grad(outputs=model_interp, inputs=interp,
                                  grad_outputs=torch.ones(model_interp.size()).to(device),
                                  create_graph=True, retain_graph=True, only_inputs=True)[0]
        grads = torch.sqrt(torch.sum(torch.square(grads), dim=[1, 2, 3]))
        grad_pen = torch.mean(torch.square(grads - 1) * lambda_gp)
        return grad_pen

    # Calculating the discriminator loss
    def dis_loss(self):
        loss_real = self.netD(self.real_batch)
        loss_real = -torch.mean(loss_real)

        loss_fake = self.netD(self.fake_batch.detach())
        loss_fake = torch.mean(loss_fake)
        
        grad_pen = self.gradient_penalty()
        
        d_loss = loss_fake + loss_real + grad_pen
        return d_loss

    def gen_loss(self):
        g_loss = self.netD(self.fake_batch)
        g_loss = -torch.mean(g_loss)
        return g_loss

    # One training step; the discriminator is trained every call while the generator is trained every d_ratio batches
    def step(self, real_batch, epoch, i):
        self.real_batch = real_batch
        noise = torch.randn(batch_size, nz, 1, 1, device=device)
        self.fake_batch = self.netG(noise)

        self.optimizerD.zero_grad()
        d_loss = self.dis_loss()
        d_loss.backward()
        self.optimizerD.step()

        if i % d_ratio == 0:
            self.optimizerG.zero_grad()
            g_loss = self.gen_loss()
            g_loss.backward()
            self.optimizerG.step()
            if i % 20 == 0:
                self.log(g_loss.item(), d_loss.item(), epoch, i)

    # Simple logging to stdout and a plt figure
    def log(self, g_loss, d_loss, epoch, i):
        plt.close('all')
        print("\nEpoch:", epoch, "Iteration:", i * batch_size)
        print("Discriminator Loss:", d_loss, "G Loss:", g_loss)
        with torch.no_grad():
            guess = self.netG(self.fixed_noise)
            guess = guess.cpu()
            # rescale from [-1, 1] to [0, 1] for display
            guess = (guess + 1) / 2
            guess = guess.permute(0, 2, 3, 1)
            fig = plt.figure(figsize=(4, 4))
            for j in range(16):
                plt.subplot(4, 4, j + 1)
                plt.imshow(guess[j, :, :])
                plt.axis('off')
            
            plt.show(block=False)
            plt.pause(0.001)
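The WGan class reads several module-level globals that are defined elsewhere in the original script. A sketch of plausible values, given purely as assumptions so the class above can be exercised (lambda_gp = 10 and betas = (0.0, 0.9) follow the WGAN-GP paper):

import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
nz = 100                       # size of the generator's latent vector
lr = 1e-4                      # learning rate for both Adam optimizers
betas = (0.0, 0.9)             # Adam betas commonly used with WGAN-GP
batch_size = 64
lambda_gp = 10                 # gradient-penalty weight
d_ratio = 5                    # train G once every d_ratio discriminator steps
checkpoint_dir = "./checkpoints/"
check_name = "wgan_checkpoint.pt"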