Example #1
# note: both defaults originally pointed at hparams.submit_file; hparams.submit_csv is an assumed attribute
def test(model_path, submit_csv=hparams.submit_csv, submit_file=hparams.submit_file, best_thresh=None):

    test_dataset = AudioData(data_csv=submit_csv, data_file=submit_file, ds_type='submit',
                        transform=transforms.Compose([
                            transforms.ToTensor(),
                        ]))

    test_loader = DataLoader(test_dataset, batch_size=hparams.batch_size,
                            shuffle=False, num_workers=2)


    discriminator = Discriminator().to(hparams.gpu_device)
    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids)
    checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    discriminator = discriminator.eval()
    # print('Model loaded')

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits_list = []
        labels_list = []
        img_names_list = []
        # for _ in range(hparams.repeat_infer):
        for (inp, labels, img_names) in tqdm(test_loader):
            # Variable is a deprecated no-op wrapper in modern PyTorch; plain tensors suffice
            inp = inp.float()
            labels = labels.long()

            inp = inp.to(hparams.gpu_device)
            labels = labels.to(hparams.gpu_device)

            inp = inp.view(-1, 1, 640, 64)
            inp = torch.cat([inp]*3, dim=1)

            pred_logits = discriminator(inp)

            pred_logits_list.append(pred_logits)
            labels_list.append(labels)
            img_names_list.append(img_names)

        pred_logits = torch.cat(pred_logits_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        pred_labels = pred_logits.max(1)[1]

        # Assumed completion: the original listing truncates at `with open`;
        # write img_name,label pairs to a hypothetical submission file.
        with open('submission.csv', 'w') as f:
            f.write('img_name,label\n')
            flat_names = [n for batch in img_names_list for n in batch]
            for name, label in zip(flat_names, pred_labels.cpu().tolist()):
                f.write('{0},{1}\n'.format(name, label))
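
A minimal, self-contained sketch of the reshaping trick above: a batch of mono spectrograms is viewed as (N, 1, 640, 64) and tiled to three channels so an ImageNet-style backbone can consume it. Only torch is assumed.

import torch

spec = torch.randn(8, 640, 64)       # a batch of mono spectrograms
inp = spec.view(-1, 1, 640, 64)      # add a channel axis
inp = torch.cat([inp] * 3, dim=1)    # tile to fake RGB
print(inp.shape)                     # torch.Size([8, 3, 640, 64])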
Example #2
class face_learner(object):
    def __init__(self, conf, inference=False):
        if conf.use_mobilfacenet:
            self.model = MobileFaceNet(conf.embedding_size).to(conf.device)
            print('MobileFaceNet model generated')
        else:
            self.model = Backbone(conf.net_depth, conf.drop_ratio,
                                  conf.net_mode).to(conf.device)
            self.growup = GrowUP().to(conf.device)
            self.discriminator = Discriminator().to(conf.device)
            print('{}_{} model generated'.format(conf.net_mode,
                                                 conf.net_depth))

        if not inference:

            self.milestones = conf.milestones
            self.loader, self.class_num = get_train_loader(conf)
            if conf.discriminator:
                self.child_loader, self.adult_loader = get_train_loader_d(conf)

            os.makedirs(conf.log_path, exist_ok=True)
            self.writer = SummaryWriter(conf.log_path)
            self.step = 0

            self.head = Arcface(embedding_size=conf.embedding_size,
                                classnum=self.class_num).to(conf.device)

            # No longer used
            if conf.use_dp:
                self.model = nn.DataParallel(self.model)
                self.head = nn.DataParallel(self.head)

            print(self.class_num)
            print(conf)

            print('two model heads generated')

            paras_only_bn, paras_wo_bn = separate_bn_paras(self.model)
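            # separate_bn_paras splits out batchnorm parameters; the optimizer
            # groups below give them no weight decay, which is standard practice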

            if conf.use_mobilfacenet:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn[:-1],
                        'weight_decay': 4e-5
                    }, {
                        'params': [paras_wo_bn[-1]] + [self.head.kernel],
                        'weight_decay': 4e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            else:
                self.optimizer = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)
            if conf.discriminator:
                self.optimizer_g = optim.Adam(self.growup.parameters(),
                                              lr=1e-4,
                                              betas=(0.5, 0.999))
                self.optimizer_g2 = optim.Adam(self.growup.parameters(),
                                               lr=1e-4,
                                               betas=(0.5, 0.999))
                self.optimizer_d = optim.Adam(self.discriminator.parameters(),
                                              lr=1e-4,
                                              betas=(0.5, 0.999))
                self.optimizer2 = optim.SGD(
                    [{
                        'params': paras_wo_bn + [self.head.kernel],
                        'weight_decay': 5e-4
                    }, {
                        'params': paras_only_bn
                    }],
                    lr=conf.lr,
                    momentum=conf.momentum)

            if conf.finetune_model_path is not None:
                self.optimizer = optim.SGD([{
                    'params': paras_wo_bn,
                    'weight_decay': 5e-4
                }, {
                    'params': paras_only_bn
                }],
                                           lr=conf.lr,
                                           momentum=conf.momentum)
            print('optimizers generated')

            self.board_loss_every = len(self.loader) // 100
            self.evaluate_every = len(self.loader) // 2
            self.save_every = len(self.loader)

            dataset_root = "/home/nas1_userD/yonggyu/Face_dataset/face_emore"
            self.lfw = np.load(
                os.path.join(dataset_root,
                             "lfw_align_112_list.npy")).astype(np.float32)
            self.lfw_issame = np.load(
                os.path.join(dataset_root, "lfw_align_112_label.npy"))
            self.fgnetc = np.load(
                os.path.join(dataset_root,
                             "FGNET_new_align_list.npy")).astype(np.float32)
            self.fgnetc_issame = np.load(
                os.path.join(dataset_root, "FGNET_new_align_label.npy"))
        else:
            # No longer used
            # self.model = nn.DataParallel(self.model)
            self.threshold = conf.threshold

    def board_val(self, db_name, accuracy, best_threshold, roc_curve_tensor,
                  negative_wrong, positive_wrong):
        self.writer.add_scalar('{}_accuracy'.format(db_name), accuracy,
                               self.step)
        self.writer.add_scalar('{}_best_threshold'.format(db_name),
                               best_threshold, self.step)
        self.writer.add_scalar('{}_negative_wrong'.format(db_name),
                               negative_wrong, self.step)
        self.writer.add_scalar('{}_positive_wrong'.format(db_name),
                               positive_wrong, self.step)
        self.writer.add_image('{}_roc_curve'.format(db_name), roc_curve_tensor,
                              self.step)

    def evaluate(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings = np.zeros([len(carray), conf.embedding_size])
        with torch.no_grad():
            while idx + conf.batch_size <= len(carray):
                batch = torch.tensor(carray[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(
                        batch.to(conf.device)).cpu() + self.model(
                            fliped.to(conf.device)).cpu()
                    embeddings[idx:idx +
                               conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray):
                batch = torch.tensor(carray[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(
                        batch.to(conf.device)).cpu() + self.model(
                            fliped.to(conf.device)).cpu()
                    embeddings[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings[idx:] = self.model(batch.to(conf.device)).cpu()
        tpr, fpr, accuracy, best_thresholds, dist = evaluate_dist(
            embeddings, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor, dist

    def evaluate_child(self, conf, carray, issame, nrof_folds=10, tta=True):
        self.model.eval()
        self.growup.eval()
        self.discriminator.eval()
        idx = 0
        embeddings1 = np.zeros([len(carray) // 2, conf.embedding_size])
        embeddings2 = np.zeros([len(carray) // 2, conf.embedding_size])

        carray1 = carray[::2, ]
        carray2 = carray[1::2, ]

        with torch.no_grad():
            while idx + conf.batch_size <= len(carray1):
                batch = torch.tensor(carray1[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                                self.growup(self.model(fliped.to(conf.device))).cpu()
                    embeddings1[idx:idx +
                                conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:idx + conf.batch_size] = self.growup(
                        self.model(batch.to(conf.device))).cpu()
                idx += conf.batch_size
            if idx < len(carray1):
                batch = torch.tensor(carray1[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.growup(self.model(batch.to(conf.device))).cpu() + \
                                self.growup(self.model(fliped.to(conf.device))).cpu()
                    embeddings1[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings1[idx:] = self.growup(
                        self.model(batch.to(conf.device))).cpu()

            idx = 0  # restart the cursor; it still points past the end of carray1
            while idx + conf.batch_size <= len(carray2):
                batch = torch.tensor(carray2[idx:idx + conf.batch_size])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings2[idx:idx +
                                conf.batch_size] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:idx + conf.batch_size] = self.model(
                        batch.to(conf.device)).cpu()
                idx += conf.batch_size
            if idx < len(carray2):
                batch = torch.tensor(carray2[idx:])
                if tta:
                    fliped = hflip_batch(batch)
                    emb_batch = self.model(batch.to(conf.device)).cpu() + \
                                self.model(fliped.to(conf.device)).cpu()
                    embeddings2[idx:] = l2_norm(emb_batch).cpu()
                else:
                    embeddings2[idx:] = self.model(batch.to(conf.device)).cpu()

        tpr, fpr, accuracy, best_thresholds = evaluate_child(
            embeddings1, embeddings2, issame, nrof_folds)
        buf = gen_plot(fpr, tpr)
        roc_curve = Image.open(buf)
        roc_curve_tensor = transforms.ToTensor()(roc_curve)
        return accuracy.mean(), best_thresholds.mean(), roc_curve_tensor

    def zero_grad(self):
        self.optimizer.zero_grad()
        self.optimizer_g.zero_grad()
        self.optimizer_d.zero_grad()

    def train(self, conf, epochs):
        self.model.train()
        running_loss = 0.
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()

            for imgs, labels, ages in tqdm(iter(self.loader)):

                self.optimizer.zero_grad()

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)

                embeddings = self.model(imgs)
                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()

                self.optimizer.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    running_loss = 0.

                # also log wrong-pair counts during evaluation
                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    # LFW evaluation
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    # NEGATIVE WRONG
                    wrong_list = np.where((self.lfw_issame == False)
                                          & (dist < best_threshold))[0]
                    negative_wrong = len(wrong_list)
                    # POSITIVE WRONG
                    wrong_list = np.where((self.lfw_issame == True)
                                          & (dist > best_threshold))[0]
                    positive_wrong = len(wrong_list)
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong,
                                   positive_wrong)

                    # FGNETC evaluation
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    # NEGATIVE WRONG
                    wrong_list = np.where((self.fgnetc_issame == False)
                                          & (dist2 < best_threshold2))[0]
                    negative_wrong2 = len(wrong_list)
                    # POSITIVE WRONG
                    wrong_list = np.where((self.fgnetc_issame == True)
                                          & (dist2 > best_threshold2))[0]
                    positive_wrong2 = len(wrong_list)
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2,
                                   positive_wrong2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth)
                                    + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        print('Hooray!')

    def train_with_growup(self, conf, epochs):
        '''
        Our method: train the recognizer jointly with the growup generator and the age discriminator
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1== adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()
                self.optimizer_g2.zero_grad()
                self.growup.train()

                c = (ages == 0)  # select children for enhancement

                embeddings = self.model(imgs)

                if sum(c) > 1:  # there may be no children in the loader's batch
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat
                elif sum(c) == 1:
                    self.growup.eval()  # batchnorm cannot compute batch stats on a single sample
                    embeddings_c = embeddings[c]
                    embeddings_a_hat = self.growup(embeddings_c)
                    embeddings[c] = embeddings_a_hat

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()
                self.optimizer_g2.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                self.growup.train()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                embeddings_a_hat = self.growup(embeddings_c)
                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # separate forward passes since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer_g.zero_grad()
                embeddings_c = self.model(imgs_c)
                embeddings_a_hat = self.growup(embeddings_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_a_hat))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # the generator should make the discriminator score children as 1 (adult)
                g_loss = conf.ls_loss(pred_c, labels_a)

                l1_loss = conf.l1_loss(embeddings_a_hat, embeddings_c)
                g_total_loss = g_loss + 10 * l1_loss
                g_total_loss.backward()

                # g_loss.backward()
                self.optimizer_g.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    negative_wrong = len(np.where((self.lfw_issame == False) & (dist < best_threshold))[0])
                    positive_wrong = len(np.where((self.lfw_issame == True) & (dist > best_threshold))[0])
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong, positive_wrong)
                    accuracy2, best_threshold2, roc_curve_tensor2 = self.evaluate_child(
                        conf, self.fgnetc, self.fgnetc_issame)
                    # evaluate_child returns no distances, so log zero wrong-pair counts
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, 0, 0)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True, extra=str(conf.data_mode)  + '_' + str(conf.net_depth)\
             + '_'+ str(conf.batch_size) +'_discriminator_final')

    def train_age_invariant(self, conf, epochs):
        '''
        Our method, without growup
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1== adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # separate forward passes since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # the generator should make the discriminator score children as 1 (adult)
                g_loss = conf.ls_loss(pred_c, labels_a)

                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    negative_wrong = len(np.where((self.lfw_issame == False) & (dist < best_threshold))[0])
                    positive_wrong = len(np.where((self.lfw_issame == True) & (dist > best_threshold))[0])
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong, positive_wrong)
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    negative_wrong2 = len(np.where((self.fgnetc_issame == False) & (dist2 < best_threshold2))[0])
                    positive_wrong2 = len(np.where((self.fgnetc_issame == True) & (dist2 > best_threshold2))[0])
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2, positive_wrong2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True, extra=str(conf.data_mode)  + '_' + str(conf.net_depth)\
             + '_'+ str(conf.batch_size) +'_discriminator_final')

    def train_age_invariant2(self, conf, epochs):
        '''
        Our method, without growup, using paired dataset TODO
        '''
        self.model.train()
        running_loss = 0.
        l1_loss = 0
        for e in range(epochs):
            print('epoch {} started'.format(e))

            if e in self.milestones:
                self.schedule_lr()
                self.schedule_lr2()

            a_loader = iter(self.adult_loader)
            c_loader = iter(self.child_loader)
            for imgs, labels, ages in tqdm(iter(self.loader)):
                # loader : base loader that returns images with id
                # a_loader, c_loader : adult, child loader with same datasize
                # ages : 0 == child, 1== adult
                try:
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)
                except StopIteration:
                    a_loader = iter(self.adult_loader)
                    c_loader = iter(self.child_loader)
                    imgs_a, labels_a = next(a_loader)
                    imgs_c, labels_c = next(c_loader)

                imgs = imgs.to(conf.device)
                labels = labels.to(conf.device)
                imgs_a, labels_a = imgs_a.to(conf.device), labels_a.to(
                    conf.device).type(torch.float32)
                imgs_c, labels_c = imgs_c.to(conf.device), labels_c.to(
                    conf.device).type(torch.float32)
                bs_a = imgs_a.shape[0]

                imgs_ac = torch.cat([imgs_a, imgs_c], dim=0)

                ###########################
                #       Train head        #
                ###########################
                self.optimizer.zero_grad()

                embeddings = self.model(imgs)

                thetas = self.head(embeddings, labels)

                loss = conf.ce_loss(thetas, labels)
                loss.backward()
                running_loss += loss.item()
                self.optimizer.step()

                ##############################
                #    Train discriminator     #
                ##############################
                self.optimizer_d.zero_grad()
                _embeddings = self.model(imgs_ac)
                embeddings_a, embeddings_c = _embeddings[:bs_a], _embeddings[
                    bs_a:]

                labels_ac = torch.cat([labels_a, labels_c], dim=0)
                pred_a = torch.squeeze(self.discriminator(
                    embeddings_a))  # separate forward passes since batchnorm exists
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                pred_ac = torch.cat([pred_a, pred_c], dim=0)
                d_loss = conf.ls_loss(pred_ac, labels_ac)
                d_loss.backward()
                self.optimizer_d.step()

                #############################
                #      Train generator      #
                #############################
                self.optimizer2.zero_grad()
                embeddings_c = self.model(imgs_c)
                pred_c = torch.squeeze(self.discriminator(embeddings_c))
                labels_a = torch.ones_like(labels_c, dtype=torch.float)
                # the generator should make the discriminator score children as 1 (adult)
                g_loss = conf.ls_loss(pred_c, labels_a)

                g_loss.backward()
                self.optimizer2.step()

                if self.step % self.board_loss_every == 0 and self.step != 0:  # XXX
                    print('tensorboard plotting....')
                    loss_board = running_loss / self.board_loss_every
                    self.writer.add_scalar('train_loss', loss_board, self.step)
                    self.writer.add_scalar('d_loss', d_loss, self.step)
                    self.writer.add_scalar('g_loss', g_loss, self.step)
                    self.writer.add_scalar('l1_loss', l1_loss, self.step)
                    running_loss = 0.

                if self.step % self.evaluate_every == 0 and self.step != 0:
                    print('evaluating....')
                    accuracy, best_threshold, roc_curve_tensor, dist = self.evaluate(
                        conf, self.lfw, self.lfw_issame)
                    negative_wrong = len(np.where((self.lfw_issame == False) & (dist < best_threshold))[0])
                    positive_wrong = len(np.where((self.lfw_issame == True) & (dist > best_threshold))[0])
                    self.board_val('lfw', accuracy, best_threshold,
                                   roc_curve_tensor, negative_wrong, positive_wrong)
                    accuracy2, best_threshold2, roc_curve_tensor2, dist2 = self.evaluate(
                        conf, self.fgnetc, self.fgnetc_issame)
                    negative_wrong2 = len(np.where((self.fgnetc_issame == False) & (dist2 < best_threshold2))[0])
                    positive_wrong2 = len(np.where((self.fgnetc_issame == True) & (dist2 > best_threshold2))[0])
                    self.board_val('fgnet_c', accuracy2, best_threshold2,
                                   roc_curve_tensor2, negative_wrong2, positive_wrong2)

                    self.model.train()

                if self.step % self.save_every == 0 and self.step != 0:
                    print('saving model....')
                    # save with most recently calculated accuracy?
                    self.save_state(conf, accuracy2, extra=str(conf.data_mode) + '_' + str(conf.net_depth) \
                        + '_' + str(conf.batch_size) + conf.model_name)

                self.step += 1
        self.save_state(conf, accuracy2, to_save_folder=True, extra=str(conf.data_mode)  + '_' + str(conf.net_depth)\
             + '_'+ str(conf.batch_size) +'_discriminator_final')

    def analyze_angle(self, conf, name):
        '''
        Only works on the age-labeled VGG and AgeDB datasets
        '''

        angle_table = [{
            0: set(),
            1: set(),
            2: set(),
            3: set(),
            4: set(),
            5: set(),
            6: set(),
            7: set()
        } for i in range(self.class_num)]
        # batch = 0
        # _angle_table = torch.zeros(self.class_num, 8, len(self.loader)//conf.batch_size).to(conf.device)
        if conf.resume_analysis:
            self.loader = []
        for imgs, labels, ages in tqdm(iter(self.loader)):

            imgs = imgs.to(conf.device)
            labels = labels.to(conf.device)
            ages = ages.to(conf.device)

            embeddings = self.model(imgs)
            if conf.use_dp:
                kernel_norm = l2_norm(self.head.module.kernel, axis=0)
                cos_theta = torch.mm(embeddings, kernel_norm)
                cos_theta = cos_theta.clamp(-1, 1)
            else:
                cos_theta = self.head.get_angle(embeddings)

            thetas = torch.abs(torch.rad2deg(torch.acos(cos_theta)))

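            # age bins implied by the arithmetic below: 0: <13, 1: 13-18,
            # 2: 19-25, 3: 26-35, 4: 36-45, 5: 46-55, 6: 56-65, 7: 66+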
            for i in range(len(thetas)):
                age_bin = 7
                if ages[i] < 26:
                    age_bin = 0 if ages[i] < 13 else 1 if ages[i] < 19 else 2
                elif ages[i] < 66:
                    age_bin = int(((ages[i] + 4) // 10).item())
                angle_table[labels[i]][age_bin].add(
                    thetas[i][labels[i]].item())

        if conf.resume_analysis:
            with open('analysis/angle_table.pkl', 'rb') as f:
                angle_table = pickle.load(f)
        else:
            with open('analysis/angle_table.pkl', 'wb') as f:
                pickle.dump(angle_table, f)

        count, avg_angle = [], []
        for i in range(self.class_num):
            count.append(
                [len(single_set) for single_set in angle_table[i].values()])
            avg_angle.append([
                sum(list(single_set)) / len(single_set)
                if len(single_set) else 0  # if set() size is zero, avg is zero
                for single_set in angle_table[i].values()
            ])

        count_df = pd.DataFrame(count)
        avg_angle_df = pd.DataFrame(avg_angle)

        with pd.ExcelWriter('analysis/analyze_angle_{}_{}.xlsx'.format(
                conf.data_mode, name)) as writer:
            count_df.to_excel(writer, sheet_name='count')
            avg_angle_df.to_excel(writer, sheet_name='avg_angle')

    def schedule_lr(self):
        for params in self.optimizer.param_groups:
            params['lr'] /= 10
        print(self.optimizer)

    def schedule_lr2(self):
        for params in self.optimizer2.param_groups:
            params['lr'] /= 10
        print(self.optimizer2)

    def infer(self, conf, faces, target_embs, tta=False):
        '''
        faces : list of PIL Image
        target_embs : [n, 512] computed embeddings of faces in facebank
        names : recorded names of faces in facebank
        tta : test time augmentation (hflip, that's all)
        '''
        embs = []
        for img in faces:
            if tta:
                mirror = transforms.functional.hflip(img)
                emb = self.model(
                    conf.test_transform(img).to(conf.device).unsqueeze(0))
                emb_mirror = self.model(
                    conf.test_transform(mirror).to(conf.device).unsqueeze(0))
                embs.append(l2_norm(emb + emb_mirror))
            else:
                embs.append(
                    self.model(
                        conf.test_transform(img).to(conf.device).unsqueeze(0)))
        source_embs = torch.cat(embs)

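        # broadcast: (n, 512, 1) - (1, 512, m) -> (n, 512, m); summing squares
        # over dim=1 yields an (n, m) matrix of squared L2 distances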
        diff = source_embs.unsqueeze(-1) - target_embs.transpose(
            1, 0).unsqueeze(0)
        dist = torch.sum(torch.pow(diff, 2), dim=1)
        minimum, min_idx = torch.min(dist, dim=1)
        min_idx[minimum > self.threshold] = -1  # if no match, set idx to -1
        return min_idx, minimum

    def save_best_state(self,
                        conf,
                        accuracy,
                        to_save_folder=False,
                        extra=None,
                        model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path

        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            str(save_path) +
            ('lfw_best_model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                str(save_path) +
                ('lfw_best_head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                str(save_path) +
                ('lfw_best_optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))

    def save_state(self,
                   conf,
                   accuracy,
                   to_save_folder=False,
                   extra=None,
                   model_only=False):
        if to_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path

        os.makedirs('work_space/models', exist_ok=True)
        torch.save(
            self.model.state_dict(),
            str(save_path) +
            ('/model_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                get_time(), accuracy, self.step, extra)))
        if not model_only:
            torch.save(
                self.head.state_dict(),
                str(save_path) +
                ('/head_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            torch.save(
                self.optimizer.state_dict(),
                str(save_path) +
                ('/optimizer_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                    get_time(), accuracy, self.step, extra)))
            if conf.discriminator:
                torch.save(
                    self.growup.state_dict(),
                    str(save_path) +
                    ('/growup_{}_accuracy:{:.3f}_step:{}_{}.pth'.format(
                        get_time(), accuracy, self.step, extra)))

    def load_state(self,
                   conf,
                   fixed_str,
                   from_save_folder=False,
                   model_only=False,
                   analyze=False):
        if from_save_folder:
            save_path = conf.save_path
        else:
            save_path = conf.model_path
        self.model.load_state_dict(
            torch.load(os.path.join(save_path, 'model_{}'.format(fixed_str))))
        if not model_only:
            self.head.load_state_dict(
                torch.load(save_path / 'head_{}'.format(fixed_str)))
            if not analyze:
                self.optimizer.load_state_dict(
                    torch.load(save_path / 'optimizer_{}'.format(fixed_str)))
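
A standalone sketch of the flip-TTA used in evaluate() above: the embedding of an image and of its horizontal flip are summed, then L2-normalized. Only torch is assumed; model stands in for self.model, and this l2_norm mirrors the helper the class relies on.

import torch

def l2_norm(x, axis=1):
    # divide each row by its L2 norm to get unit-length embeddings
    return x / torch.norm(x, 2, axis, keepdim=True)

def embed_with_tta(model, batch):
    flipped = torch.flip(batch, dims=[3])   # horizontal flip over the width axis
    emb = model(batch) + model(flipped)     # sum the two views
    return l2_norm(emb)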
Example #3
            # bert_optimizer.step()
            # gen_optimizer.step()
            # dis_optimizer.step()

            tr_g_loss += g_loss.item()
            tr_d_loss += d_loss.item()
            nb_tr_examples += src_input_ids.size(0)
            nb_tr_steps += 1
            global_step += 1

        tr_g_loss /= nb_tr_steps
        tr_d_loss /= nb_tr_steps

        # VALIDATION
        bert.eval()
        discriminator.eval()

        all_preds = np.array([])
        all_label_ids = np.array([])
        eval_loss = 0
        nb_eval_steps = 0
        for src_input_ids, src_input_mask, label_ids in val_dataloader:
            src_input_ids = src_input_ids.to(device)
            src_input_mask = src_input_mask.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                _, doc_rep = bert(src_input_ids, attention_mask=src_input_mask)
                _, logits, probs = discriminator(doc_rep)
                print(probs)
                probs = torch.nn.functional.softmax(probs[:, :-1], dim=-1)
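                # Assumed completion (the original snippet is truncated here):
                # take the argmax over the real classes and accumulate
                preds = torch.argmax(probs, dim=-1).cpu().numpy()
                all_preds = np.concatenate([all_preds, preds])
                all_label_ids = np.concatenate([all_label_ids, label_ids.cpu().numpy()])
            nb_eval_steps += 1

        val_accuracy = (all_preds == all_label_ids).mean()
        print('val accuracy: {:.4f}'.format(val_accuracy))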
Example #4
class Solver(object):

    ####
    def __init__(self, args):

        self.args = args

        self.name = ( '%s_etaS_%s_etaH_%s_lamklMin_%s_lamklMax_%s' + \
                      '_gamma_%s_zDim_%s' ) % \
            ( args.dataset, args.etaS, args.etaH, \
              args.lamklMin, args.lamklMax, args.gamma, args.z_dim )
        # to be appended by run_id

        self.use_cuda = args.cuda and torch.cuda.is_available()

        self.max_iter = int(args.max_iter)

        # do it every specified iters
        self.print_iter = args.print_iter
        self.ckpt_save_iter = args.ckpt_save_iter
        self.output_save_iter = args.output_save_iter

        # data info
        self.dset_dir = args.dset_dir
        self.dataset = args.dataset
        if args.dataset.endswith('dsprites'):
            self.nc = 1
        else:
            self.nc = 3

        # default off; enabled below only when ground-truth factor labels exist
        # and args.eval_metrics is set (train() reads this flag unconditionally)
        self.eval_metrics = False

        # groundtruth factor labels (only available for "dsprites")
        if self.dataset == 'dsprites':

            # latent factor = (color, shape, scale, orient, pos-x, pos-y)
            #   color = {1} (1)
            #   shape = {1=square, 2=oval, 3=heart} (3)
            #   scale = {0.5, 0.6, ..., 1.0} (6)
            #   orient = {2*pi*(k/39)}_{k=0}^39 (40)
            #   pos-x = {k/31}_{k=0}^31 (32)
            #   pos-y = {k/31}_{k=0}^31 (32)
            # (number of variations = 1*3*6*40*32*32 = 737280)

            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[:, [1, 2, 3, 4, 5]]
            # latent values (actual values);(737280 x 5)
            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            self.latent_classes = latent_classes[:, [1, 2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (737280 x 5)
            self.latent_sizes = np.array([3, 6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # groundtruth factor labels
        elif self.dataset == 'oval_dsprites':

            latent_classes = np.load(os.path.join(self.dset_dir,
                                                  'dsprites-dataset',
                                                  'latents_classes.npy'),
                                     encoding='latin1')
            idx = np.where(latent_classes[:, 1] == 1)[0]  # "oval" shape only
            self.latent_classes = latent_classes[idx, :]
            self.latent_classes = self.latent_classes[:, [2, 3, 4, 5]]
            # classes ({0,1,...,K}-valued); (245760 x 4)
            latent_values = np.load(os.path.join(self.dset_dir,
                                                 'dsprites-dataset',
                                                 'latents_values.npy'),
                                    encoding='latin1')
            self.latent_values = latent_values[idx, :]
            self.latent_values = self.latent_values[:, [2, 3, 4, 5]]
            # latent values (actual values);(245760 x 4)

            self.latent_sizes = np.array([6, 40, 32, 32])
            self.N = self.latent_values.shape[0]

            if args.eval_metrics:
                self.eval_metrics = True
                self.eval_metrics_iter = args.eval_metrics_iter

        # networks and optimizers
        self.batch_size = args.batch_size
        self.z_dim = args.z_dim
        self.etaS = args.etaS
        self.etaH = args.etaH
        self.lamklMin = args.lamklMin
        self.lamklMax = args.lamklMax
        self.gamma = args.gamma
        self.lr_VAE = args.lr_VAE
        self.beta1_VAE = args.beta1_VAE
        self.beta2_VAE = args.beta2_VAE
        #        self.lr_rvec = args.lr_rvec
        #        self.beta1_rvec = args.beta1_rvec
        #        self.beta2_rvec = args.beta2_rvec
        self.lr_D = args.lr_D
        self.beta1_D = args.beta1_D
        self.beta2_D = args.beta2_D

        # visdom setup
        self.viz_on = args.viz_on
        if self.viz_on:

            self.win_id = dict(DZ='win_DZ',
                               recon='win_recon',
                               kl='win_kl',
                               rvS='win_rvS',
                               rvH='win_rvH')
            self.line_gather = DataGather('iter', 'p_DZ', 'p_DZ_perm', 'recon',
                                          'kl', 'rvS', 'rvH')

            if self.eval_metrics:
                self.win_id['metrics'] = 'win_metrics'

            import visdom

            self.viz_port = args.viz_port  # port number, eg, 8097
            self.viz = visdom.Visdom(port=self.viz_port)
            self.viz_ll_iter = args.viz_ll_iter
            self.viz_la_iter = args.viz_la_iter

            self.viz_init()

        # create dirs: "records", "ckpts", "outputs" (if not exist)
        mkdirs("records")
        mkdirs("ckpts")
        mkdirs("outputs")

        # set run id
        if args.run_id < 0:  # create a new id
            k = 0
            rfname = os.path.join("records", self.name + '_run_0.txt')
            while os.path.exists(rfname):
                k += 1
                rfname = os.path.join("records", self.name + '_run_%d.txt' % k)
            self.run_id = k
        else:  # user-provided id
            self.run_id = args.run_id

        # finalize name
        self.name = self.name + '_run_' + str(self.run_id)

        # records (text file to store console outputs)
        self.record_file = 'records/%s.txt' % self.name

        # checkpoints
        self.ckpt_dir = os.path.join("ckpts", self.name)

        # outputs
        self.output_dir_recon = os.path.join("outputs", self.name + '_recon')
        # dir for reconstructed images
        self.output_dir_synth = os.path.join("outputs", self.name + '_synth')
        # dir for synthesized images
        self.output_dir_trvsl = os.path.join("outputs", self.name + '_trvsl')
        # dir for latent traversed images

        #### create a new model or load a previously saved model

        self.ckpt_load_iter = args.ckpt_load_iter

        if self.ckpt_load_iter == 0:  # create a new model

            # create a vae model
            if args.dataset.endswith('dsprites'):
                self.encoder = Encoder1(self.z_dim)
                self.decoder = Decoder1(self.z_dim)
            else:
                pass  #self.VAE = FactorVAE2(self.z_dim)

            # create a relevance vector
            self.rvec = RelevanceVector(self.z_dim)

            # create a discriminator model
            self.D = Discriminator(self.z_dim)

        else:  # load a previously saved model

            print('Loading saved models (iter: %d)...' % self.ckpt_load_iter)
            self.load_checkpoint()
            print('...done')

        if self.use_cuda:
            print('Models moved to GPU...')
            self.encoder = self.encoder.cuda()
            self.decoder = self.decoder.cuda()
            self.rvec = self.rvec.cuda()
            self.D = self.D.cuda()
            print('...done')

        # get VAE parameters (and rv parameters)
        vae_params = list(self.encoder.parameters()) + \
          list(self.decoder.parameters()) + list(self.rvec.parameters())

        # get discriminator parameters
        dis_params = list(self.D.parameters())

        # create optimizers
        self.optim_vae = optim.Adam(vae_params,
                                    lr=self.lr_VAE,
                                    betas=[self.beta1_VAE, self.beta2_VAE])
        self.optim_dis = optim.Adam(dis_params,
                                    lr=self.lr_D,
                                    betas=[self.beta1_D, self.beta2_D])

    ####
    def train(self):

        self.set_mode(train=True)

        ones = torch.ones(self.batch_size, dtype=torch.long)
        zeros = torch.zeros(self.batch_size, dtype=torch.long)
        if self.use_cuda:
            ones = ones.cuda()
            zeros = zeros.cuda()

        # prepare dataloader (iterable)
        print('Start loading data...')
        self.data_loader = create_dataloader(self.args)
        print('...done')

        # iterators from dataloader
        iterator1 = iter(self.data_loader)
        iterator2 = iter(self.data_loader)

        iter_per_epoch = min(len(iterator1), len(iterator2))

        start_iter = self.ckpt_load_iter + 1
        epoch = int(start_iter / iter_per_epoch)

        for iteration in range(start_iter, self.max_iter + 1):

            # reset data iterators for each epoch
            if iteration % iter_per_epoch == 0:
                print('==== epoch %d done ====' % epoch)
                epoch += 1
                iterator1 = iter(self.data_loader)
                iterator2 = iter(self.data_loader)

            #============================================
            #          TRAIN THE VAE (ENC & DEC)
            #============================================

            # sample a mini-batch
            X, ids = next(iterator1)  # (n x C x H x W)
            if self.use_cuda:
                X = X.cuda()

            # enc(X)
            mu, std, logvar = self.encoder(X)

            # relevance vector
            rvlogit, rv = self.rvec()

            # kl loss
            kls = -0.5 * (1 + logvar - mu**2 - std**2)  # (n x z_dim)
            klsum = kls.sum(1).mean()
            lamkl = self.lamklMax - (self.lamklMax - self.lamklMin) * rv
            loss_kl = (lamkl * kls).sum(1).mean()
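            # rv ~ 1 (relevant dim) weights its KL by ~lamklMin; rv ~ 0
            # (irrelevant) weights it by ~lamklMax, pushing unused dims to the prior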

            # reparam'ed samples
            if self.use_cuda:
                Eps = torch.cuda.FloatTensor(mu.shape).normal_()
            else:
                Eps = torch.randn(mu.shape)
            Z = mu + Eps * std

            # dec(Z)
            X_recon = self.decoder(Z)

            # recon loss
            loss_recon = F.binary_cross_entropy_with_logits(
                X_recon, X, reduction='sum').div(X.size(0))

            # dis(rv*Z)
            DZ = self.D(rv * Z)

            # tc loss
            loss_tc = (DZ[:, 0] - DZ[:, 1]).mean()
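            # density-ratio trick: this class-0/class-1 logit gap estimates
            # log q(z) / prod_j q(z_j), i.e. the total correlation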

            # L1 (sparseness) loss
            loss_sparse = rv.sum()

            # entropy loss
            loss_entropy = F.binary_cross_entropy_with_logits(rvlogit,
                                                              rv,
                                                              reduction='sum')
            #loss_entropy = (rv*(1-rv)).sum()

            # total loss for vae
            vae_loss = loss_recon + loss_kl + self.gamma*loss_tc + \
                       self.etaS*loss_sparse + self.etaH*loss_entropy

            # update vae
            self.optim_vae.zero_grad()
            vae_loss.backward()
            self.optim_vae.step()

            #============================================
            #          TRAIN THE DISCRIMINATOR
            #============================================

            # sample a mini-batch
            X2, ids = next(iterator2)  # (n x C x H x W)
            if self.use_cuda:
                X2 = X2.cuda()

            # enc(X2)
            mu, std, _ = self.encoder(X2)

            # reparam'ed samples
            if self.use_cuda:
                Eps = torch.cuda.FloatTensor(mu.shape).normal_()
            else:
                Eps = torch.randn(mu.shape)
            Z = mu + Eps * std

            # relevance vector
            _, rv = self.rvec()

            RZ = rv * Z

            # dis(RZ)
            DZ = self.D(RZ)

            # dim-wise permutated Z over the mini-batch
            perm_Z = []
            for zj in RZ.split(1, 1):
                idx = torch.randperm(Z.size(0))
                perm_zj = zj[idx]
                perm_Z.append(perm_zj)
            RZ_perm = torch.cat(perm_Z, 1)
            RZ_perm = RZ_perm.detach()

            # dis(RZ_perm)
            DZ_perm = self.D(RZ_perm)

            # discriminator loss
            dis_loss = 0.5 * (F.cross_entropy(DZ, zeros) +
                              F.cross_entropy(DZ_perm, ones))
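            # D learns to call joint codes (RZ) class 0 and dimension-wise
            # permuted codes (RZ_perm, ~ product of marginals) class 1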

            # update discriminator
            self.optim_dis.zero_grad()
            dis_loss.backward()
            self.optim_dis.step()

            # print the losses
            if iteration % self.print_iter == 0:
                prn_str = ( '[iter %d (epoch %d)] vae_loss: %.3f | ' + \
                    'dis_loss: %.3f\n    ' + \
                    '(recon: %.3f, kl: %.3f, tc: %.3f, L1: %.3f, H: %.3f)' \
                  ) % \
                  ( iteration, epoch, vae_loss.item(), dis_loss.item(),
                    loss_recon.item(), klsum.item(), loss_tc.item(),
                    loss_sparse.item(), loss_entropy.item() )
                prn_str += '\n    rv = {}'.format(
                    rv.detach().cpu().numpy().round(2))
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str, ))
                    record.close()

            # save model parameters
            if iteration % self.ckpt_save_iter == 0:
                self.save_checkpoint(iteration)

            # save output images (recon, synth, etc.)
            if iteration % self.output_save_iter == 0:

                # 1) save the recon images
                self.save_recon(iteration, X, torch.sigmoid(X_recon).data)

                # 2) save the synth images
                self.save_synth(iteration, howmany=100)

                # 3) save the latent traversed images
                if self.dataset.lower() == '3dchairs':
                    self.save_traverse(iteration, limb=-2, limu=2, inter=0.5)
                else:
                    self.save_traverse(iteration, limb=-3, limu=3, inter=0.1)

            # (visdom) insert current line stats
            if self.viz_on and (iteration % self.viz_ll_iter == 0):

                # compute discriminator accuracy
                p_DZ = F.softmax(DZ, 1)[:, 0].detach()
                p_DZ_perm = F.softmax(DZ_perm, 1)[:, 0].detach()

                # insert line stats
                self.line_gather.insert(iter=iteration,
                                        p_DZ=p_DZ.mean().item(),
                                        p_DZ_perm=p_DZ_perm.mean().item(),
                                        recon=loss_recon.item(),
                                        kl=klsum.item(),
                                        rvS=loss_sparse.item(),
                                        rvH=loss_entropy.item())

            # (visdom) visualize line stats (then flush out)
            if self.viz_on and (iteration % self.viz_la_iter == 0):
                self.visualize_line()
                self.line_gather.flush()

            # evaluate metrics
            if self.eval_metrics and (iteration % self.eval_metrics_iter == 0):

                metric1, _ = self.eval_disentangle_metric1()
                metric2, _ = self.eval_disentangle_metric2()

                prn_str = ( '********\n[iter %d (epoch %d)] ' + \
                  'metric1 = %.4f, metric2 = %.4f\n********' ) % \
                  (iteration, epoch, metric1, metric2)
                print(prn_str)
                if self.record_file:
                    record = open(self.record_file, 'a')
                    record.write('%s\n' % (prn_str, ))
                    record.close()

                # (visdom) visualize metrics
                if self.viz_on:
                    self.visualize_line_metrics(iteration, metric1, metric2)

    ####
    def eval_disentangle_metric1(self):

        # some hyperparams
        num_pairs = 800  # number of (d, y) data pairs for majority-vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic

        dl = DataLoader(self.data_loader.dataset,
                        batch_size=bs,
                        shuffle=True,
                        num_workers=self.args.num_workers,
                        pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):

            # sample a mini-batch
            Xb, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                Xb = Xb.cuda()

            # enc(Xb)
            mub, _, _ = self.encoder(Xb)  # (bs x z_dim)

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate the sample variance of the latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor fixed"

        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, dtype=int)  # true factor ids

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor

                # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):

                # a true factor (id and class value) to fix
                fac_id = j
                fac_class = np.random.randint(self.latent_sizes[fac_id])

                # randomly select images (with the fixed factor)
                indices = np.where(self.latent_classes[:,
                                                       fac_id] == fac_class)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]
                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if Xb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        Xb = Xb.cuda()
                    mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = fac_id

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        smallest_var_dims = np.argmin(vars_per_factor /
                                      (vars_agn_factor + 1e-20),
                                      axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[smallest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric1 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric1, C

    ####
    def eval_disentangle_metric2(self):

        # some hyperparams
        num_pairs = 800  # number of data pairs (d,y) for majority-vote classification
        bs = 50  # batch size
        nsamps_per_factor = 100  # samples per factor
        nsamps_agn_factor = 5000  # factor-agnostic samples

        self.set_mode(train=False)

        # 1) estimate variances of latent points factor agnostic

        dl = DataLoader(self.data_loader.dataset,
                        batch_size=bs,
                        shuffle=True,
                        num_workers=self.args.num_workers,
                        pin_memory=True)
        iterator = iter(dl)

        M = []
        for ib in range(int(nsamps_agn_factor / bs)):

            # sample a mini-batch
            Xb, _ = next(iterator)  # (bs x C x H x W)
            if self.use_cuda:
                Xb = Xb.cuda()

            # enc(Xb)
            mub, _, _ = self.encoder(Xb)  # (bs x z_dim)

            M.append(mub.cpu().detach().numpy())

        M = np.concatenate(M, 0)

        # estimate the sample variance of the latent points for each dim
        vars_agn_factor = np.var(M, 0)

        # 2) estimate dim-wise vars of latent points with "one factor varied"

        factor_ids = range(0, len(self.latent_sizes))  # true factor ids
        vars_per_factor = np.zeros([num_pairs, self.z_dim])
        true_factor_ids = np.zeros(num_pairs, dtype=int)  # true factor ids

        # prepare data pairs for majority-vote classification
        i = 0
        for j in factor_ids:  # for each factor

                # repeat num_pairs/num_factors times
            for r in range(int(num_pairs / len(factor_ids))):

                # randomly choose true factors (id's and class values) to fix
                fac_ids = list(np.setdiff1d(factor_ids, j))
                fac_classes = \
                  [ np.random.randint(self.latent_sizes[k]) for k in fac_ids ]

                # randomly select images (with the other factors fixed)
                if len(fac_ids) > 1:
                    indices = np.where(
                        np.sum(self.latent_classes[:, fac_ids] == fac_classes,
                               1) == len(fac_ids))[0]
                else:
                    indices = np.where(
                        self.latent_classes[:, fac_ids] == fac_classes)[0]
                np.random.shuffle(indices)
                idx = indices[:nsamps_per_factor]
                M = []
                for ib in range(int(nsamps_per_factor / bs)):
                    Xb, _ = dl.dataset[idx[(ib * bs):(ib + 1) * bs]]
                    if Xb.shape[0] < 1:  # no more samples
                        continue
                    if self.use_cuda:
                        Xb = Xb.cuda()
                    mub, _, _ = self.encoder(Xb)  # (bs x z_dim)
                    M.append(mub.cpu().detach().numpy())
                M = np.concatenate(M, 0)

                # estimate sample var and mean of latent points for each dim
                if M.shape[0] >= 2:
                    vars_per_factor[i, :] = np.var(M, 0)
                else:  # not enough samples to estimate variance
                    vars_per_factor[i, :] = 0.0

                # true factor id (will become the class label)
                true_factor_ids[i] = j

                i += 1

        # 3) evaluate majority vote classification accuracy

        # inputs in the paired data for classification
        largest_var_dims = np.argmax(vars_per_factor /
                                     (vars_agn_factor + 1e-20),
                                     axis=1)

        # contingency table
        C = np.zeros([self.z_dim, len(factor_ids)])
        for i in range(num_pairs):
            C[largest_var_dims[i], true_factor_ids[i]] += 1

        num_errs = 0  # number of misclassification errors of the majority-vote classifier
        for k in range(self.z_dim):
            num_errs += np.sum(C[k, :]) - np.max(C[k, :])

        metric2 = (num_pairs - num_errs) / num_pairs  # metric = accuracy

        self.set_mode(train=True)

        return metric2, C

    ####
    def save_recon(self, iters, true_images, recon_images):

        # make an interleaved merge of true and recon, e.g.,
        #   merged[0,...] = true[0,...],
        #   merged[1,...] = recon[0,...],
        #   merged[2,...] = true[1,...],
        #   merged[3,...] = recon[1,...], ...

        n = true_images.shape[0]
        perm = torch.arange(0, 2 * n).view(2, n).transpose(1, 0)
        perm = perm.contiguous().view(-1)
        merged = torch.cat([true_images, recon_images], dim=0)
        merged = merged[perm, :].cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_recon, 'recon_%s.jpg' % iters)
        mkdirs(self.output_dir_recon)
        save_image(tensor=merged,
                   filename=fname,
                   nrow=2 * int(np.sqrt(n)),
                   pad_value=1)

    ####
    def save_synth(self, iters, howmany=100):

        self.set_mode(train=False)

        decoder = self.decoder

        Z = torch.randn(howmany, self.z_dim)
        if self.use_cuda:
            Z = Z.cuda()

        # do synthesis
        X = torch.sigmoid(decoder(Z)).data.cpu()

        # save the results as image
        fname = os.path.join(self.output_dir_synth, 'synth_%s.jpg' % iters)
        mkdirs(self.output_dir_synth)
        save_image(tensor=X,
                   filename=fname,
                   nrow=int(np.sqrt(howmany)),
                   pad_value=1)

        self.set_mode(train=True)

    ####
    def save_traverse(self, iters, limb=-3, limu=3, inter=2 / 3, loc=-1):

        self.set_mode(train=False)

        encoder = self.encoder
        decoder = self.decoder
        interpolation = torch.arange(limb, limu + 0.001, inter)

        i = np.random.randint(self.N)
        random_img = self.data_loader.dataset.__getitem__(i)[0]
        if self.use_cuda:
            random_img = random_img.cuda()
        random_img = random_img.unsqueeze(0)
        random_img_zmu, _, _ = encoder(random_img)

        if self.dataset.lower() == 'dsprites':

            fixed_idx1 = 87040  # square
            fixed_idx2 = 332800  # ellipse
            fixed_idx3 = 578560  # heart

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed_square': fixed_img1,
                'fixed_ellipse': fixed_img2,
                'fixed_heart': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed_square': fixed_img_zmu1,
                'fixed_ellipse': fixed_img_zmu2,
                'fixed_heart': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

        elif self.dataset.lower() == 'oval_dsprites':

            fixed_idx1 = 87040  # oval1
            fixed_idx2 = 220045  # oval2
            fixed_idx3 = 178560  # oval3

            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
            if self.use_cuda:
                fixed_img1 = fixed_img1.cuda()
            fixed_img1 = fixed_img1.unsqueeze(0)
            fixed_img_zmu1, _, _ = encoder(fixed_img1)

            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
            if self.use_cuda:
                fixed_img2 = fixed_img2.cuda()
            fixed_img2 = fixed_img2.unsqueeze(0)
            fixed_img_zmu2, _, _ = encoder(fixed_img2)

            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
            if self.use_cuda:
                fixed_img3 = fixed_img3.cuda()
            fixed_img3 = fixed_img3.unsqueeze(0)
            fixed_img_zmu3, _, _ = encoder(fixed_img3)

            IMG = {
                'fixed1': fixed_img1,
                'fixed2': fixed_img2,
                'fixed3': fixed_img3,
                'random_img': random_img
            }

            Z = {
                'fixed1': fixed_img_zmu1,
                'fixed2': fixed_img_zmu2,
                'fixed3': fixed_img_zmu3,
                'random_img': random_img_zmu
            }

#        elif self.dataset.lower() == 'celeba':
#
#            fixed_idx1 = 191281 # 'CelebA/img_align_celeba/191282.jpg'
#            fixed_idx2 = 143307 # 'CelebA/img_align_celeba/143308.jpg'
#            fixed_idx3 = 101535 # 'CelebA/img_align_celeba/101536.jpg'
#            fixed_idx4 = 70059  # 'CelebA/img_align_celeba/070060.jpg'
#
#            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
#            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
#            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
#
#            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
#            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
#            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
#
#            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
#            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
#            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
#
#            fixed_img4 = self.data_loader.dataset.__getitem__(fixed_idx4)[0]
#            fixed_img4 = fixed_img4.to(self.device).unsqueeze(0)
#            fixed_img_z4 = encoder(fixed_img4)[:, :self.z_dim]
#
#            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
#                 'fixed_3':fixed_img_z3, 'fixed_4':fixed_img_z4,
#                 'random':random_img_zmu}
#
#        elif self.dataset.lower() == '3dchairs':
#
#            fixed_idx1 = 40919 # 3DChairs/images/4682_image_052_p030_t232_r096.png
#            fixed_idx2 = 5172  # 3DChairs/images/14657_image_020_p020_t232_r096.png
#            fixed_idx3 = 22330 # 3DChairs/images/30099_image_052_p030_t232_r096.png
#
#            fixed_img1 = self.data_loader.dataset.__getitem__(fixed_idx1)[0]
#            fixed_img1 = fixed_img1.to(self.device).unsqueeze(0)
#            fixed_img_z1 = encoder(fixed_img1)[:, :self.z_dim]
#
#            fixed_img2 = self.data_loader.dataset.__getitem__(fixed_idx2)[0]
#            fixed_img2 = fixed_img2.to(self.device).unsqueeze(0)
#            fixed_img_z2 = encoder(fixed_img2)[:, :self.z_dim]
#
#            fixed_img3 = self.data_loader.dataset.__getitem__(fixed_idx3)[0]
#            fixed_img3 = fixed_img3.to(self.device).unsqueeze(0)
#            fixed_img_z3 = encoder(fixed_img3)[:, :self.z_dim]
#
#            Z = {'fixed_1':fixed_img_z1, 'fixed_2':fixed_img_z2,
#                 'fixed_3':fixed_img_z3, 'random':random_img_zmu}
#
        else:

            raise NotImplementedError

        # do traversal and collect generated images
        gifs = []
        for key in Z:
            z_ori = Z[key]
            for row in range(self.z_dim):
                if loc != -1 and row != loc:
                    continue
                z = z_ori.clone()
                for val in interpolation:
                    z[:, row] = val
                    sample = torch.sigmoid(decoder(z)).data
                    gifs.append(sample)

        # save the generated files, also the animated gifs
        out_dir = os.path.join(self.output_dir_trvsl, str(iters))
        mkdirs(self.output_dir_trvsl)
        mkdirs(out_dir)
        gifs = torch.cat(gifs)
        gifs = gifs.view(len(Z), self.z_dim, len(interpolation), self.nc, 64,
                         64).transpose(1, 2)
        for i, key in enumerate(Z.keys()):
            for j, val in enumerate(interpolation):
                I = torch.cat([IMG[key], gifs[i][j]], dim=0)
                save_image(tensor=I.cpu(),
                           filename=os.path.join(out_dir,
                                                 '%s_%03d.jpg' % (key, j)),
                           nrow=1 + self.z_dim,
                           pad_value=1)
            # make animated gif
            grid2gif(out_dir,
                     key,
                     str(os.path.join(out_dir, key + '.gif')),
                     delay=10)

        self.set_mode(train=True)

    ####
    def viz_init(self):

        self.viz.close(env=self.name + '/lines', win=self.win_id['DZ'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['recon'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['kl'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['rvS'])
        self.viz.close(env=self.name + '/lines', win=self.win_id['rvH'])

        if self.eval_metrics:
            self.viz.close(env=self.name + '/lines',
                           win=self.win_id['metrics'])

    ####
    def visualize_line(self):

        # prepare data to plot
        data = self.line_gather.data
        iters = torch.Tensor(data['iter'])
        recon = torch.Tensor(data['recon'])
        kl = torch.Tensor(data['kl'])
        rvS = torch.Tensor(data['rvS'])
        rvH = torch.Tensor(data['rvH'])

        p_DZ = torch.Tensor(data['p_DZ'])
        p_DZ_perm = torch.Tensor(data['p_DZ_perm'])
        p_DZs = torch.stack([p_DZ, p_DZ_perm], -1)  # (#items x 2)

        self.viz.line(X=iters,
                      Y=p_DZs,
                      env=self.name + '/lines',
                      win=self.win_id['DZ'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='D(z)',
                                title='Discriminator-Z',
                                legend=[
                                    'D(z)',
                                    'D(z_perm)',
                                ]))

        self.viz.line(X=iters,
                      Y=recon,
                      env=self.name + '/lines',
                      win=self.win_id['recon'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='recon loss',
                                title='Reconstruction'))

        self.viz.line(X=iters,
                      Y=kl,
                      env=self.name + '/lines',
                      win=self.win_id['kl'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='E_x[KL(q(z|x)||p(z))]',
                                title='KL divergence'))

        self.viz.line(X=iters,
                      Y=rvS,
                      env=self.name + '/lines',
                      win=self.win_id['rvS'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='||rv||_1',
                                title='L1 norm of relevance vector'))

        self.viz.line(X=iters,
                      Y=rvH,
                      env=self.name + '/lines',
                      win=self.win_id['rvH'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='H(rv)',
                                title='Entropy of relevance vector'))

    ####
    def visualize_line_metrics(self, iters, metric1, metric2):

        # prepare data to plot
        iters = torch.tensor([iters], dtype=torch.int64).detach()
        metric1 = torch.tensor([metric1])
        metric2 = torch.tensor([metric2])
        metrics = torch.stack([metric1.detach(), metric2.detach()], -1)

        self.viz.line(X=iters,
                      Y=metrics,
                      env=self.name + '/lines',
                      win=self.win_id['metrics'],
                      update='append',
                      opts=dict(xlabel='iter',
                                ylabel='metrics',
                                title='Disentanglement metrics',
                                legend=['metric1', 'metric2']))

    ####
    def set_mode(self, train=True):

        if train:
            self.encoder.train()
            self.decoder.train()
            self.D.train()
        else:
            self.encoder.eval()
            self.decoder.eval()
            self.D.eval()

    ####
    def save_checkpoint(self, iteration):

        encoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_encoder.pt' % iteration)
        decoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_decoder.pt' % iteration)
        rvec_path = os.path.join(self.ckpt_dir, 'iter_%s_rvec.pt' % iteration)
        D_path = os.path.join(self.ckpt_dir, 'iter_%s_D.pt' % iteration)

        mkdirs(self.ckpt_dir)

        torch.save(self.encoder, encoder_path)
        torch.save(self.decoder, decoder_path)
        torch.save(self.rvec, rvec_path)
        torch.save(self.D, D_path)

    ####
    def load_checkpoint(self):

        encoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_encoder.pt' % self.ckpt_load_iter)
        decoder_path = os.path.join(self.ckpt_dir,
                                    'iter_%s_decoder.pt' % self.ckpt_load_iter)
        rvec_path = os.path.join(self.ckpt_dir,
                                 'iter_%s_rvec.pt' % self.ckpt_load_iter)
        D_path = os.path.join(self.ckpt_dir,
                              'iter_%s_D.pt' % self.ckpt_load_iter)

        if self.use_cuda:
            self.encoder = torch.load(encoder_path)
            self.decoder = torch.load(decoder_path)
            self.rvec = torch.load(rvec_path)
            self.D = torch.load(D_path)
        else:
            self.encoder = torch.load(encoder_path, map_location='cpu')
            self.decoder = torch.load(decoder_path, map_location='cpu')
            self.rvec = torch.load(rvec_path, map_location='cpu')
            self.D = torch.load(D_path, map_location='cpu')
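A note on the two disentanglement metrics implemented above: eval_disentangle_metric1 fixes one ground-truth factor and votes with the latent dimension of smallest normalized variance, while eval_disentangle_metric2 fixes every factor except one and votes with the dimension of largest normalized variance. Both share the same majority-vote scoring step; a minimal self-contained sketch of that step (function and argument names here are illustrative, not from the original code):

import numpy as np

def majority_vote_accuracy(chosen_dims, true_factor_ids, z_dim, n_factors):
    # contingency table: rows = latent dims, columns = true factors
    C = np.zeros((z_dim, n_factors))
    for d, f in zip(chosen_dims, true_factor_ids):
        C[d, f] += 1
    # each latent dim votes for its majority factor; the rest are errors
    num_errs = (C.sum(axis=1) - C.max(axis=1)).sum()
    return 1.0 - num_errs / len(chosen_dims)

# e.g. majority_vote_accuracy(smallest_var_dims, true_factor_ids,
#                             z_dim=10, n_factors=5)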
Example #5
            # clip critic weights between -0.01 and 0.01
            for p in critic.parameters():
                p.data.clamp_(-WEIGHT_CLIP, WEIGHT_CLIP)

        # Train Generator: max E[critic(gen_fake)] <-> min -E[critic(gen_fake)]
        gen_fake = critic(fake).reshape(-1)
        loss_gen = -torch.mean(gen_fake)
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0 and batch_idx > 0:
            gen.eval()
            critic.eval()
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_critic:.4f}, loss G: {loss_gen:.4f}")

            with torch.no_grad():
                fake = gen(noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(data[:32],
                                                            normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32],
                                                            normalize=True)

                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)
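The fragment above is the tail of a WGAN training step: the critic's weights are clipped to [-WEIGHT_CLIP, WEIGHT_CLIP] to approximate the Lipschitz constraint, and the generator minimizes -E[critic(gen_fake)]. A compact sketch of the matching critic update, assuming the usual WGAN setup (critic, gen, opt_critic, real, noise are placeholders):

import torch

def critic_step(critic, gen, opt_critic, real, noise, weight_clip=0.01):
    fake = gen(noise).detach()
    # critic maximizes E[critic(real)] - E[critic(fake)],
    # i.e. minimizes the negated difference
    loss_critic = -(critic(real).mean() - critic(fake).mean())
    opt_critic.zero_grad()
    loss_critic.backward()
    opt_critic.step()
    # weight clipping keeps the critic approximately Lipschitz
    for p in critic.parameters():
        p.data.clamp_(-weight_clip, weight_clip)
    return loss_critic.item()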
Example #6
def train(resume_path=None, jigsaw_path=None):

    writer = SummaryWriter('../runs/'+hparams.exp_name)

    for k in hparams.__dict__.keys():
        writer.add_text(str(k), str(hparams.__dict__[k]))

    train_dataset = ChestData(data_csv=hparams.train_csv, data_dir=hparams.train_dir, augment=hparams.augment,
                        transform=transforms.Compose([
                            transforms.Resize(hparams.image_shape),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
                        ]))

    validation_dataset = ChestData(data_csv=hparams.valid_csv, data_dir=hparams.valid_dir,
                        transform=transforms.Compose([
                            transforms.Resize(hparams.image_shape),
                            transforms.ToTensor(),
                            transforms.Normalize((0.5027, 0.5027, 0.5027), (0.2915, 0.2915, 0.2915))
                        ]))

    # train_sampler = WeightedRandomSampler()

    train_loader = DataLoader(train_dataset, batch_size=hparams.batch_size,
                            shuffle=True, num_workers=2)

    validation_loader = DataLoader(validation_dataset, batch_size=hparams.batch_size,
                            shuffle=True, num_workers=2)

    print('loaded train data of length : {}'.format(len(train_dataset)))

    adversarial_loss = torch.nn.BCELoss().to(hparams.gpu_device)
    discriminator = Discriminator().to(hparams.gpu_device)

#     if hparams.cuda:
#         discriminator = nn.DataParallel(discriminator, device_ids=hparams.device_ids)

    params_count = 0
    for param in discriminator.parameters():
        params_count += np.prod(param.size())
    print('Model has {0} trainable parameters'.format(params_count))

    if not hparams.pretrained:
#         discriminator.apply(weights_init_normal)
        pass
#     if jigsaw_path:
#         jigsaw = Jigsaw().to(hparams.gpu_device)
#         if hparams.cuda:
#             jigsaw = nn.DataParallel(jigsaw, device_ids=hparams.device_ids)
#         checkpoints = torch.load(jigsaw_path, map_location=hparams.gpu_device)
#         jigsaw.load_state_dict(checkpoints['discriminator_state_dict'])
#         discriminator.module.model.features = jigsaw.module.feature.features
#         print('loaded pretrained feature extractor from {} ..'.format(jigsaw_path))

    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=hparams.learning_rate, betas=(0.9, 0.999))

    scheduler_D = ReduceLROnPlateau(optimizer_D, mode='min', factor=0.3, patience=0, verbose=True, cooldown=0)

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    def validation(discriminator_, send_stats=False, epoch=0):
        print('Validating model on {0} examples. '.format(len(validation_dataset)))

        with torch.no_grad():
            pred_logits_list = []
            labels_list = []

            for (img, labels, imgs_names) in tqdm(validation_loader):
                img = Variable(img.float(), requires_grad=False)
                labels = Variable(labels.float(), requires_grad=False)

                img_ = img.to(hparams.gpu_device)
                labels = labels.to(hparams.gpu_device)
                
                pred_logits = discriminator_(img_)

                pred_logits_list.append(pred_logits)
                labels_list.append(labels)

            pred_logits = torch.cat(pred_logits_list, dim=0)
            labels = torch.cat(labels_list, dim=0)

            val_loss = adversarial_loss(pred_logits, labels)

        return accuracy_metrics(labels.long(), pred_logits), val_loss

    print('Starting training.. (log saved in:{})'.format(hparams.exp_name))
    start_time = time.time()
    best_valid_auc = 0

    # print(model)
    for epoch in range(hparams.num_epochs):
        for batch, (imgs, labels, imgs_name) in enumerate(tqdm(train_loader)):

            imgs = Variable(imgs.float(), requires_grad=False)
            labels = Variable(labels.float(), requires_grad=False)

            imgs_ = imgs.to(hparams.gpu_device)
            labels = labels.to(hparams.gpu_device)

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            pred_logits, aux_logits = discriminator(imgs_)

            d_loss1 = adversarial_loss(pred_logits, labels)
            d_loss2 = adversarial_loss(aux_logits, labels)
            
            d_loss = d_loss1 + 0.4 * d_loss2

            d_loss.backward()
            optimizer_D.step()

            writer.add_scalar('d_loss', d_loss.item(), global_step=batch+epoch*len(train_loader))

            pred_labels = (pred_logits >= hparams.thresh)
            pred_labels = pred_labels.float()

            # if batch % hparams.print_interval == 0:
            #     auc, f1, acc, _, _ = accuracy_metrics(pred_labels, labels.long(), pred_logits)
            #     print('[Epoch - {0:.1f}, batch - {1:.3f}, d_loss - {2:.6f}, acc - {3:.4f}, f1 - {4:.5f}, auc - {5:.4f}]'.\
            #     format(1.0*epoch, 100.0*batch/len(train_loader), d_loss.item(), acc['avg'], f1[hparams.avg_mode], auc[hparams.avg_mode]))
        (val_auc, val_f1, val_acc, val_conf_mat, best_thresh), val_loss = validation(discriminator.eval(), epoch=epoch)
        discriminator = discriminator.train()
        for lbl in range(hparams.num_classes):
            fig = plot_cf(val_conf_mat[lbl])
            writer.add_figure('val_conf_{}'.format(hparams.id_to_class[lbl]), fig, global_step=epoch)
            plt.close(fig)
            writer.add_scalar('val_f1_{}'.format(hparams.id_to_class[lbl]), val_f1[lbl], global_step=epoch)
            writer.add_scalar('val_auc_{}'.format(hparams.id_to_class[lbl]), val_auc[lbl], global_step=epoch)
            writer.add_scalar('val_acc_{}'.format(hparams.id_to_class[lbl]), val_acc[lbl], global_step=epoch)
        writer.add_scalar('val_f1_{}'.format('micro'), val_f1['micro'], global_step=epoch)
        writer.add_scalar('val_auc_{}'.format('micro'), val_auc['micro'], global_step=epoch)
        writer.add_scalar('val_f1_{}'.format('macro'), val_f1['macro'], global_step=epoch)
        writer.add_scalar('val_auc_{}'.format('macro'), val_auc['macro'], global_step=epoch)
        writer.add_scalar('val_loss', val_loss, global_step=epoch)
        writer.add_scalar('val_f1', val_f1[hparams.avg_mode], global_step=epoch)
        writer.add_scalar('val_auc', val_auc[hparams.avg_mode], global_step=epoch)
        writer.add_scalar('val_acc', val_acc['avg'], global_step=epoch)
        scheduler_D.step(val_loss)
        writer.add_scalar('learning_rate', optimizer_D.param_groups[0]['lr'], global_step=epoch)

        torch.save({
            'epoch': epoch,
            'discriminator_state_dict': discriminator.state_dict(),
            'optimizer_D_state_dict': optimizer_D.state_dict(),
            }, hparams.model+'.'+str(epoch))
        if best_valid_auc <= val_auc[hparams.avg_mode]:
            best_valid_auc = val_auc[hparams.avg_mode]
            for lbl in range(hparams.num_classes):
                fig = plot_cf(val_conf_mat[lbl])
                writer.add_figure('best_val_conf_{}'.format(hparams.id_to_class[lbl]), fig, global_step=epoch)
                plt.close(fig)
            torch.save({
                'epoch': epoch,
                'discriminator_state_dict': discriminator.state_dict(),
                'optimizer_D_state_dict': optimizer_D.state_dict(),
                }, hparams.model+'.best')
            print('best model on validation set saved.')
        print('[Epoch - {0:.1f} ---> val_auc - {1:.4f}, current_lr - {2:.6f}, val_loss - {3:.4f}, best_val_auc - {4:.4f}, val_acc - {5:.4f}, val_f1 - {6:.4f}] - time - {7:.1f}'\
            .format(1.0*epoch, val_auc[hparams.avg_mode], optimizer_D.param_groups[0]['lr'], val_loss, best_valid_auc, val_acc['avg'], val_f1[hparams.avg_mode], time.time()-start_time))
        start_time = time.time()
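Two patterns in the training loop above are worth isolating: ReduceLROnPlateau is stepped with the validation loss so the learning rate drops as soon as validation stops improving (patience=0), and a separate '.best' checkpoint is saved whenever the validation AUC improves. A minimal sketch of the combination (the model and metric values are stand-ins):

import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = torch.nn.Linear(10, 2)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
sched = ReduceLROnPlateau(opt, mode='min', factor=0.3, patience=0)

best_auc = 0.0
for epoch in range(3):
    val_loss, val_auc = 0.5, 0.8  # stand-ins for real validation metrics
    sched.step(val_loss)          # LR shrinks when val_loss plateaus
    if val_auc >= best_auc:       # snapshot the best model by val AUC
        best_auc = val_auc
        torch.save({'epoch': epoch,
                    'state_dict': model.state_dict()}, 'model.best')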
Example #7
class BiAAE(object):
    def __init__(self, params):

        self.params = params
        self.tune_dir = "{}/{}-{}/{}".format(params.exp_id, params.src_lang,
                                             params.tgt_lang,
                                             params.norm_embeddings)
        self.tune_best_dir = "{}/best".format(self.tune_dir)

        self.X_AE = AE(params)
        self.Y_AE = AE(params)
        self.D_X = Discriminator(input_size=params.d_input_size,
                                 hidden_size=params.d_hidden_size,
                                 output_size=params.d_output_size)
        self.D_Y = Discriminator(input_size=params.d_input_size,
                                 hidden_size=params.d_hidden_size,
                                 output_size=params.d_output_size)

        self.nets = [self.X_AE, self.Y_AE, self.D_X, self.D_Y]
        self.loss_fn = torch.nn.BCELoss()
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def weights_init(self, m):  # orthogonal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init2(self, m):  # Xavier-normal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init3(self, m):  # identity-matrix initialization
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(
                torch.diag(torch.ones(self.params.g_input_size)))

    def freeze(self, m):
        for p in m.parameters():
            p.requires_grad = False

    def defreeze(self, m):
        for p in m.parameters():
            p.requires_grad = True

    def init_state(self, seed=-1):
        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            for net in self.nets:
                net.cuda()
            self.loss_fn = self.loss_fn.cuda()
            self.loss_fn2 = self.loss_fn2.cuda()

        print('Initializing the model...')
        self.X_AE.apply(self.weights_init)  # the G initialization scheme can be changed
        self.Y_AE.apply(self.weights_init)  # the G initialization scheme can be changed

        self.D_X.apply(self.weights_init2)
        #print(self.D_X.map1.weight)
        self.D_Y.apply(self.weights_init2)

    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed):
        # Load data
        if not os.path.exists(self.params.data_dir):
            print("Data path doesn't exists: %s" % self.params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)

        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]
        en = src_emb
        it = tgt_emb

        #eval = Evaluator(self.params, en,it, torch.cuda.is_available())

        AE_optimizer = optim.SGD(filter(
            lambda p: p.requires_grad,
            list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),
                                 lr=self.params.g_learning_rate)
        D_optimizer = optim.SGD(list(self.D_X.parameters()) +
                                list(self.D_Y.parameters()),
                                lr=self.params.d_learning_rate)

        D_A_acc_epochs = []
        D_B_acc_epochs = []
        D_A_loss_epochs = []
        D_B_loss_epochs = []
        d_loss_epochs = []
        G_AB_loss_epochs = []
        G_BA_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        g_loss_epochs = []
        L_Z_loss_epoches = []

        acc_epochs = []

        criterion_epochs = []
        best_valid_metric = -100

        try:
            for epoch in range(self.params.num_epochs):
                D_A_losses = []
                D_B_losses = []
                G_AB_losses = []
                G_AB_recon = []
                G_BA_losses = []
                G_adv_losses = []
                G_BA_recon = []
                L_Z_losses = []
                d_losses = []
                g_losses = []
                hit_A = 0
                hit_B = 0
                total = 0
                start_time = timer()
                # lowest_loss = 1e5
                label_D = to_variable(
                    torch.FloatTensor(2 * self.params.mini_batch_size).zero_())
                label_D[:self.params.mini_batch_size] = 1 - self.params.smoothing
                label_D[self.params.mini_batch_size:] = self.params.smoothing

                label_G = to_variable(
                    torch.FloatTensor(self.params.mini_batch_size).zero_())
                label_G = label_G + 1 - self.params.smoothing

                for mini_batch in range(
                        0, self.params.iters_in_epoch //
                        self.params.mini_batch_size):
                    for d_index in range(self.params.d_steps):
                        D_optimizer.zero_grad()  # Reset the gradients
                        self.D_X.train()
                        self.D_Y.train()

                        view_X, view_Y = self.get_batch_data_fast(en, it)

                        # Discriminator X
                        Y_Z = self.Y_AE.encode(view_Y).detach()
                        fake_X = self.X_AE.decode(Y_Z).detach()
                        input = torch.cat([view_X, fake_X], 0)

                        pred_A = self.D_X(input)
                        D_A_loss = self.loss_fn(pred_A, label_D)

                        # Discriminator Y
                        X_Z = self.X_AE.encode(view_X).detach()
                        fake_Y = self.Y_AE.decode(X_Z).detach()

                        input = torch.cat([view_Y, fake_Y], 0)
                        pred_B = self.D_Y(input)
                        D_B_loss = self.loss_fn(pred_B, label_D)

                        D_loss = D_A_loss + self.params.gate * D_B_loss

                        # compute/store gradients, but don't change params
                        D_loss.backward()
                        d_losses.append(to_numpy(D_loss.data))
                        D_A_losses.append(to_numpy(D_A_loss.data))
                        D_B_losses.append(to_numpy(D_B_loss.data))

                        discriminator_decision_A = to_numpy(pred_A.data)
                        hit_A += np.sum(
                            discriminator_decision_A[:self.params.
                                                     mini_batch_size] >= 0.5)
                        hit_A += np.sum(
                            discriminator_decision_A[self.params.
                                                     mini_batch_size:] < 0.5)

                        discriminator_decision_B = to_numpy(pred_B.data)
                        hit_B += np.sum(
                            discriminator_decision_B[:self.params.
                                                     mini_batch_size] >= 0.5)
                        hit_B += np.sum(
                            discriminator_decision_B[self.params.
                                                     mini_batch_size:] < 0.5)

                        # Only optimizes D's parameters, using the gradients stored by backward()
                        D_optimizer.step()

                        # Clip weights
                        #_clip(self.D_X, self.params.clip_value)
                        #_clip(self.D_Y, self.params.clip_value)

                        sys.stdout.write(
                            "[%d/%d] :: Discriminator Loss: %.3f \r" %
                            (mini_batch, self.params.iters_in_epoch //
                             self.params.mini_batch_size,
                             float(np.mean(d_losses))))
                        sys.stdout.flush()

                    total += 2 * self.params.mini_batch_size * self.params.d_steps

                    for g_index in range(self.params.g_steps):
                        # 2. Train G on D's response (but DO NOT train D on these labels)
                        AE_optimizer.zero_grad()
                        self.D_X.eval()
                        self.D_Y.eval()
                        view_X, view_Y = self.get_batch_data_fast(en, it)

                        # Generator X_AE
                        ## adversarial loss
                        X_Z = self.X_AE.encode(view_X)
                        X_recon = self.X_AE.decode(X_Z)
                        Y_fake = self.Y_AE.decode(X_Z)
                        pred_Y = self.D_Y(Y_fake)
                        L_adv_X = self.loss_fn(pred_Y, label_G)

                        L_recon_X = 1.0 - torch.mean(
                            self.loss_fn2(view_X, X_recon))

                        # Generator Y_AE
                        # adversarial loss
                        Y_Z = self.Y_AE.encode(view_Y)
                        Y_recon = self.Y_AE.decode(Y_Z)
                        X_fake = self.X_AE.decode(Y_Z)
                        pred_X = self.D_X(X_fake)
                        L_adv_Y = self.loss_fn(pred_X, label_G)

                        # autoencoder (reconstruction) loss
                        L_recon_Y = 1.0 - torch.mean(
                            self.loss_fn2(view_Y, Y_recon))

                        # cross-lingual Loss
                        L_Z = 1.0 - torch.mean(self.loss_fn2(X_Z, Y_Z))

                        G_loss = self.params.adv_weight * (self.params.gate*L_adv_X + L_adv_Y) + \
                                self.params.mono_weight * (L_recon_X+L_recon_Y) + \
                                self.params.cross_weight * L_Z

                        G_loss.backward()

                        g_losses.append(to_numpy(G_loss.data))
                        G_AB_losses.append(to_numpy(L_adv_X.data))
                        G_BA_losses.append(to_numpy(L_adv_Y.data))
                        G_adv_losses.append(
                            to_numpy(L_adv_Y.data + L_adv_X.data))
                        G_AB_recon.append(to_numpy(L_recon_X.data))
                        G_BA_recon.append(to_numpy(L_recon_Y.data))
                        L_Z_losses.append(to_numpy(L_Z.data))

                        AE_optimizer.step()  # Only optimizes G's parameters

                        sys.stdout.write(
                            "[%d/%d] ::                                     Generator Loss: %.3f \r"
                            % (mini_batch, self.params.iters_in_epoch //
                               self.params.mini_batch_size,
                               float(np.mean(g_losses))))
                        sys.stdout.flush()
                '''for each epoch'''

                D_A_acc_epochs.append(hit_A / total)
                D_B_acc_epochs.append(hit_B / total)
                G_AB_loss_epochs.append(float(np.mean(G_AB_losses)))
                G_BA_loss_epochs.append(float(np.mean(G_BA_losses)))
                D_A_loss_epochs.append(float(np.mean(D_A_losses)))
                D_B_loss_epochs.append(float(np.mean(D_B_losses)))
                G_AB_recon_epochs.append(float(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(float(np.mean(G_BA_recon)))
                L_Z_loss_epoches.append(float(np.mean(L_Z_losses)))
                d_loss_epochs.append(float(np.mean(d_losses)))
                g_loss_epochs.append(float(np.mean(g_losses)))

                print(
                    "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins"
                    .format(epoch, float(np.mean(d_losses)),
                            0.5 * (hit_A + hit_B) / total,
                            float(np.mean(g_losses)),
                            (timer() - start_time) / 60))

                if (epoch + 1) % self.params.print_every == 0:
                    # No need for discriminator weights

                    X_Z = self.X_AE.encode(Variable(en)).data
                    Y_Z = self.Y_AE.encode(Variable(it)).data

                    mstart_time = timer()
                    for method in [self.params.eval_method]:
                        results = get_word_translation_accuracy(
                            self.params.src_lang,
                            src_word2id,
                            X_Z,
                            self.params.tgt_lang,
                            tgt_word2id,
                            Y_Z,
                            method=method,
                            dico_eval=self.params.eval_file)
                        acc1 = results[0][1]

                    print('{} takes {:.2f}s'.format(method,
                                                    timer() - mstart_time))
                    print('Method:{} score:{:.4f}'.format(method, acc1))

                    csls, size = dist_mean_cosine(self.params, X_Z, Y_Z)
                    criterion = size
                    if criterion > best_valid_metric:
                        print("New criterion value: {}".format(criterion))
                        best_valid_metric = criterion
                        fp = open(
                            self.tune_best_dir +
                            "/seed_{}_dico_{}_gate_{}_epoch_{}_acc_{:.3f}.tmp".
                            format(seed, self.params.dico_build,
                                   self.params.gate, epoch, acc1), 'w')
                        fp.close()
                        torch.save(
                            self.X_AE.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_X.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))
                        torch.save(
                            self.Y_AE.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_Y.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))
                        torch.save(
                            self.D_X.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_Dx.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))
                        torch.save(
                            self.D_Y.state_dict(), self.tune_best_dir +
                            '/seed_{}_dico_{}_gate_{}_best_Dy.t7'.format(
                                seed, self.params.dico_build,
                                self.params.gate))

                    # Saving generator weights
                    fp = open(
                        self.tune_dir +
                        "/seed_{}_gate_{}_epoch_{}_acc_{:.3f}.tmp".format(
                            seed, self.params.gate, epoch, acc1), 'w')
                    fp.close()

                    acc_epochs.append(acc1)
                    criterion_epochs.append(criterion)

            criterion_fb, epoch_fb = max([
                (score, index) for index, score in enumerate(criterion_epochs)
            ])
            fp = open(
                self.tune_best_dir +
                "/seed_{}_dico_{}_gate_{}_epoch_{}_Acc_{:.3f}_{:.4f}.cslsfb".
                format(seed, self.params.dico_build, self.params.gate,
                       epoch_fb, acc_epochs[epoch_fb], criterion_fb), 'w')
            fp.close()

            # Save the plot for discriminator accuracy and generator loss
            fig = plt.figure()
            plt.plot(range(0, len(D_A_acc_epochs)),
                     D_A_acc_epochs,
                     color='b',
                     label='D_A')
            plt.plot(range(0, len(D_B_acc_epochs)),
                     D_B_acc_epochs,
                     color='r',
                     label='D_B')
            plt.ylabel('D_accuracy')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_acc.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(D_A_loss_epochs)),
                     D_A_loss_epochs,
                     color='b',
                     label='D_A')
            plt.plot(range(0, len(D_B_loss_epochs)),
                     D_B_loss_epochs,
                     color='r',
                     label='D_B')
            plt.ylabel('D_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_loss_epochs)),
                     G_AB_loss_epochs,
                     color='b',
                     label='G_AB')
            plt.plot(range(0, len(G_BA_loss_epochs)),
                     G_BA_loss_epochs,
                     color='r',
                     label='G_BA')
            plt.ylabel('G_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_recon_epochs)),
                     G_AB_recon_epochs,
                     color='b',
                     label='G_AB')
            plt.plot(range(0, len(G_BA_recon_epochs)),
                     G_BA_recon_epochs,
                     color='r',
                     label='G_BA')
            plt.ylabel('G_recon_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_Recon.png'.format(seed))

            # fig = plt.figure()
            # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z')
            # plt.ylabel('L_Z_loss')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(tune_dir + '/seed_{}_L_Z.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(acc_epochs)),
                     acc_epochs,
                     color='b',
                     label='trans_acc1')
            plt.ylabel('trans_acc')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_trans_acc.png'.format(seed))
            '''
            fig = plt.figure()
            plt.plot(range(0, len(csls_epochs)), csls_epochs, color='b', label='csls')
            plt.ylabel('csls')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_csls.png'.format(seed))
            '''
            fig = plt.figure()
            plt.plot(range(0, len(g_loss_epochs)),
                     g_loss_epochs,
                     color='b',
                     label='G_loss')
            plt.ylabel('g_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_g_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(d_loss_epochs)),
                     d_loss_epochs,
                     color='b',
                     label='D_loss')
            plt.ylabel('D_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_d_loss.png'.format(seed))
            plt.close('all')

        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(),
                       self.tune_dir + '/X_AE_model_interrupt.t7')
            torch.save(self.Y_AE.state_dict(),
                       self.tune_dir + '/Y_AE_model_interrupt.t7')
            torch.save(self.D_X.state_dict(),
                       self.tune_dir + '/D_X_model_interrupt.t7')
            torch.save(self.D_Y.state_dict(),
                       self.tune_dir + '/D_y_model_interrupt.t7')
            exit()

        return

    def get_batch_data_fast(self, emb_en, emb_it):

        params = self.params
        random_en_indices = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        random_it_indices = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        en_batch = to_variable(emb_en)[random_en_indices.cuda()]
        it_batch = to_variable(emb_it)[random_it_indices.cuda()]

        return en_batch, it_batch
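The BiAAE generator objective above combines three terms: adversarial losses for both translation directions, per-language reconstruction losses, and a cross-lingual alignment loss, weighted by adv_weight, mono_weight, and cross_weight. The reconstruction and alignment terms are each 1 minus a mean cosine similarity; a small sketch of that loss on toy embeddings:

import torch

cos = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

def cosine_loss(a, b):
    # 1 - mean cosine similarity, as used for L_recon_X, L_recon_Y and L_Z
    return 1.0 - cos(a, b).mean()

x = torch.randn(8, 300)
x_hat = x + 0.1 * torch.randn_like(x)
print(cosine_loss(x, x_hat))  # close to 0 for near-identical embeddings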
Example #8
            # reset the gradients to avoid gradients accumulation
            discriminator_optim.zero_grad()
            # compute the gradients of loss w.r.t weights
            discriminator_loss.backward(retain_graph=True)
            # update the weights
            discriminator_optim.step()
            # Store the loss for later use
            train_history['discriminator_loss'].append(
                discriminator_loss.item())

            # Train generator
            # create fake labels
            trick = torch.tensor(
                np.array([1] * noise_size),
                dtype=torch.float32).unsqueeze(dim=1).to(device)
            discriminator.eval()  # freeze the discriminator
            if stop == 0:
                generator.train()  # enable training mode for the generator
                generator_loss = generator_criterion(
                    discriminator(generated_data), trick)
                generator_optim.zero_grad()
                generator_loss.backward(retain_graph=True)
                generator_optim.step()
                train_history['generator_loss'].append(generator_loss.item())
            else:
                generator.eval()  # enable evaluation mode
                generator_loss = generator_criterion(
                    discriminator(generated_data), trick)
                train_history['generator_loss'].append(generator_loss.item())

            # unfreeze the discriminator's layers
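Note that in the fragment above, discriminator.eval() only switches layers such as dropout and batch norm to inference mode; it does not stop gradients from flowing into the discriminator. The code relies on generator_optim updating only the generator's parameters. To freeze the discriminator explicitly one would toggle requires_grad, for example with a helper like this (the name is illustrative):

import torch

def set_requires_grad(module, flag):
    # actually freeze/unfreeze parameters; .eval() alone does not do this
    for p in module.parameters():
        p.requires_grad = flag

# usage around the generator step:
#   set_requires_grad(discriminator, False)
#   generator_loss.backward(); generator_optim.step()
#   set_requires_grad(discriminator, True)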
Example #9
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--G_path", help="Generator mdoel path")
    parser.add_argument("--D_path", help="Discriminator mdoel path")
    parser.add_argument("--dir_path",
                        help="path to load LR images and store SR images")
    parser.add_argument("--batch_size", type=int, help="Batch size")
    parser.add_argument("--res_blocks", type=int, help="No. of resnet blocks")
    parser.add_argument("--in_channels",
                        type=int,
                        help="No. of input channels")
    parser.add_argument("--train", type=int, help="Train - 1 or Test - 0")
    parser.add_argument("--downsample",
                        nargs='?',
                        const=True,
                        default=False,
                        help="Downsampling GAN")

    args = parser.parse_args()

    pathG = args.G_path
    pathD = args.D_path
    srDir = args.dir_path

    test_batch_size = args.batch_size
    shuffle_dataset = True
    random_seed = 42

    root = srDir
    dataset = TDF(root, 25, 2, args.downsample, args.train)
    dataset_size = len(dataset)
    print(dataset_size)
    indices = list(range(dataset_size))
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)

    indices_test = indices

    print(len(indices_test))
    # Creating PT data samplers and loaders:
    test_sampler = SubsetRandomSampler(indices_test)
    test_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=test_batch_size,
                                              sampler=test_sampler)

    # Num batches
    num_batches = len(test_loader)
    print(num_batches)

    G = Generator(args.in_channels, 2, args.res_blocks, args.downsample)
    G.load_state_dict(torch.load(pathG))
    G.eval()
    D = Discriminator(args.in_channels)
    D.load_state_dict(torch.load(pathD))
    D.eval()

    G.cuda()
    D.cuda()

    coords = ["0", "0", "0"]
    with torch.no_grad():
        for index, (lr, hr, filName) in enumerate(test_loader):

            lr = lr.float()
            val_z = Variable(lr)
            val_z = val_z.cuda()
            sr_test = G(val_z)
            val_target = Variable(hr)
            val_target = val_target.cuda()
            hr_test = D(val_target).mean()
            hr_fake = D(sr_test).mean()

            utils.write_voxels(args.batch_size, srDir, sr_test, index,
                               args.downsample, "test", coords, filName)
            if (index) % 50 == 0:
                print(index)

        print(torch.cuda.memory_allocated())
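The test pipeline above shuffles the index list once with a fixed seed and then hands it to SubsetRandomSampler, which is the standard way to draw a reproducible subset without touching the dataset itself. A minimal sketch on a toy dataset (names are illustrative):

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, SubsetRandomSampler

dataset = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
indices = list(range(len(dataset)))
np.random.seed(42)       # fixed seed makes the shuffle reproducible
np.random.shuffle(indices)
loader = DataLoader(dataset, batch_size=16,
                    sampler=SubsetRandomSampler(indices))
# note: when a sampler is given, shuffle=True must not also be set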
Example #10
def test(model_path,
         data=(hparams.valid_csv, hparams.dev_file),
         plot_auc='valid',
         plot_path=hparams.result_dir + 'valid',
         best_thresh=None):

    test_dataset = AudioData(data_csv=data[0],
                             data_file=data[1],
                             ds_type='valid',
                             augment=True,
                             transform=transforms.Compose([
                                 transforms.ToTensor(),
                             ]))

    test_loader = DataLoader(test_dataset,
                             batch_size=hparams.batch_size,
                             shuffle=True,
                             num_workers=2)

    discriminator = Discriminator().to(hparams.gpu_device)
    if hparams.cuda:
        discriminator = nn.DataParallel(discriminator,
                                        device_ids=hparams.device_ids)
    checkpoint = torch.load(model_path, map_location=hparams.gpu_device)
    discriminator.load_state_dict(checkpoint['discriminator_state_dict'])

    discriminator = discriminator.eval()
    # print('Model loaded')

    Tensor = torch.cuda.FloatTensor if hparams.cuda else torch.FloatTensor

    print('Testing model on {0} examples. '.format(len(test_dataset)))

    with torch.no_grad():
        pred_logits_list = []
        labels_list = []
        img_names_list = []
        # for _ in range(hparams.repeat_infer):
        for (inp, labels, img_names) in tqdm(test_loader):
            inp = Variable(inp.float(), requires_grad=False)
            labels = Variable(labels.long(), requires_grad=False)

            inp = inp.to(hparams.gpu_device)
            labels = labels.to(hparams.gpu_device)

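            # when the backbone expects 3-channel input, reshape the mono
            # spectrogram to (N, 1, 640, 64) and tile it across channels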
            if hparams.dim3:
                inp = inp.view(-1, 1, 640, 64)
                inp = torch.cat([inp] * 3, dim=1)

            pred_logits = discriminator(inp)

            pred_logits_list.append(pred_logits)
            labels_list.append(labels)
            img_names_list.append(img_names)

        pred_logits = torch.cat(pred_logits_list, dim=0)
        labels = torch.cat(labels_list, dim=0)

        auc, f1, acc, conf_mat = accuracy_metrics(labels,
                                                  pred_logits,
                                                  plot_auc=plot_auc,
                                                  plot_path=plot_path,
                                                  best_thresh=best_thresh)

        fig = plot_cf(conf_mat)
        plt.savefig(hparams.result_dir + 'test_conf_mat.png')
        res = ' -- avg_acc - {0:.4f}'.format(acc['avg'])
        for it in range(10):
            res += ', acc_{}'.format(
                hparams.id_to_class[it]) + ' - {0:.4f}'.format(acc[it])
        print('== Test on -- ' + model_path + res)
        # print('== Test on -- '+model_path+' == \n\
        #     auc_{0} - {10:.4f}, auc_{1} - {11:.4f}, auc_{2} - {12:.4f}, auc_{3} - {13:.4f}, auc_{4} - {14:.4f}, auc_{5} - {15:.4f}, auc_{6} - {16:.4f}, auc_{7} - {17:.4f}, auc_{8} - {18:.4f}, auc_{9} - {19:.4f}, auc_micro - {20:.4f}, auc_macro - {21:.4f},\n\
        #     acc_{0} - {22:.4f}, acc_{1} - {23:.4f}, acc_{2} - {24:.4f}, acc_{3} - {25:.4f}, acc_{4} - {26:.4f}, acc_{5} - {27:.4f}, acc_{6} - {28:.4f}, acc_{7} - {29:.4f}, acc_{8} - {30:.4f}, acc_{9} - {31:.4f}, acc_avg - {32:.4f},\n\
        #     f1_{0} - {33:.4f}, f1_{1} - {34:.4f}, f1_{2} - {35:.4f}, f1_{3} - {36:.4f}, f1_{4} - {37:.4f}, f1_{5} - {38:.4f}, f1_{6} - {39:.4f}, f1_{7} - {40:.4f}, f1_{8} - {41:.4f}, f1_{9} - {42:.4f}, f1_micro - {42:.4f}, f1_macro - {43:.4f}, =='.\
        #     format([hparams.id_to_class[it] for it in range(10)]+[auc[it] for it in range(10)]+[auc['micro'], auc['macro']]+[acc[it] for it in range(10)]+[acc['avg']]+[f1[it] for it in range(10)]+[f1['micro'], f1['macro']]))
    return acc['avg']
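
A minimal sketch of calling this helper, assuming a trained checkpoint exists; the checkpoint path below is hypothetical:

if __name__ == '__main__':
    avg_acc = test('checkpoints/discriminator_best.pth',  # hypothetical path
                   data=(hparams.valid_csv, hparams.dev_file),
                   plot_auc='valid')
    print('Average validation accuracy: {:.4f}'.format(avg_acc))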
Exemple #11
0
        print(
            'Train density Loss: {:.4f} Density Adversarial Loss: {:.4f}  Discriminator Loss: {:.4f}'
            .format(loss_dens_value / iter_count, loss_adv_value / iter_count,
                    loss_D_value / iter_count))
        logger.scalar_summary('Temporal/train_density_loss', loss_dens_value,
                              epoch)
        logger.scalar_summary('Temporal/train_adv_loss', loss_adv_value, epoch)
        logger.scalar_summary('Temporal/train_D_loss', loss_D_value, epoch)
        # MAE & RMSE on the training set
        epoch_mae = running_mae / totalnum
        epoch_mse = np.sqrt(running_mse / totalnum)
        print('Training Iteration:{} MAE: {:.4f} MSE: {:.4f}'.format(
            epoch, epoch_mae, epoch_mse))

        # validation phase
        net.eval()
        net_D.eval()

        running_loss = 0.0
        running_mse = 0.0
        running_mae = 0.0
        totalnum = 0
        for idx, (image, densityMap) in enumerate(val_loader):
            image = image.to(device)
            densityMap = densityMap.to(device)

            optimizer.zero_grad()
            duration = time.time()
            predDensityMap = net(image)

            outputs_np = predDensityMap.data.cpu().numpy()
            densityMap_np = densityMap.data.cpu().numpy()
Exemple #12
0
class Solver(object):
    def __init__(self, args):
        # model
        self.g_optimizer = None
        self.d_optimizer = None
        self.generator = None
        self.discriminator = None
        self.MSELoss = None
        self.L1loss = None
        self.GPU_IN_USE = torch.cuda.is_available()
        self.device = torch.device('cuda' if self.GPU_IN_USE else 'cpu')

        # Training settings
        self.dataset = args.dataset
        self.num_epochs = args.num_epochs
        self.batch_size = args.batch_size
        self.threads = args.threads
        self.g_conv_dim = args.g_conv_dim
        self.d_conv_dim = args.d_conv_dim
        self.in_channel = args.in_channel
        self.out_channel = args.out_channel
        self.use_sigmoid = False

        # hyper-parameters
        self.lr = args.lr
        self.beta_1 = args.beta_1
        self.lamb = args.lamb

        # dataloader
        self.training_data_loader = None
        self.testing_data_loader = None

    def build_model(self):
        self.generator = Generator(in_channel=self.in_channel,
                                   out_channel=self.out_channel,
                                   g_conv_dim=self.g_conv_dim,
                                   norm_layer=nn.BatchNorm2d,
                                   use_dropout=False,
                                   n_blocks=9).to(self.device)
        self.generator.normal_init()
        self.discriminator = Discriminator(
            in_channel=self.in_channel + self.out_channel,
            d_conv_dim=self.d_conv_dim,
            num_layers=3,
            norm_layer=nn.BatchNorm2d,
            use_sigmoid=self.use_sigmoid).to(self.device)
        self.discriminator.normal_init()
        self.MSELoss = nn.MSELoss()
        self.L1loss = nn.L1Loss()

        if self.GPU_IN_USE:
            self.MSELoss.cuda()
            self.L1loss.cuda()
            cudnn.benchmark = True

        self.g_optimizer = torch.optim.Adam(self.generator.parameters(),
                                            lr=self.lr,
                                            betas=(self.beta_1, 0.999))
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                            lr=self.lr,
                                            betas=(self.beta_1, 0.999))

    def build_dataset(self):
        root_path = "datasets/"
        train_set = get_training_set(root_path + self.dataset)
        test_set = get_test_set(root_path + self.dataset)
        self.training_data_loader = DataLoader(dataset=train_set,
                                               num_workers=self.threads,
                                               batch_size=self.batch_size,
                                               shuffle=True)
        self.testing_data_loader = DataLoader(dataset=test_set,
                                              num_workers=self.threads,
                                              batch_size=self.batch_size,
                                              shuffle=False)

    @staticmethod
    def to_data(x):
        """Convert variable to tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.d_optimizer.zero_grad()
        self.g_optimizer.zero_grad()

    @staticmethod
    def de_normalize(x):
        """Convert range (-1, 1) to (0, 1)"""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def checkpoint(self, epoch):
        if not os.path.exists("checkpoint"):
            os.mkdir("checkpoint")
        if not os.path.exists(os.path.join("checkpoint", self.dataset)):
            os.mkdir(os.path.join("checkpoint", self.dataset))
        net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth".format(
            self.dataset, epoch)
        net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth".format(
            self.dataset, epoch)
        torch.save(self.generator, net_g_model_out_path)
        torch.save(self.discriminator, net_d_model_out_path)
        print("Checkpoint saved to {}".format("checkpoint" + self.dataset))

    def mode_switch(self, mode):
        if mode == 'train':
            self.discriminator.train()
            self.generator.train()

        elif mode == 'eval':
            self.discriminator.eval()
            self.generator.eval()

    def train(self):
        self.mode_switch('train')
        for i, (data, target) in enumerate(self.training_data_loader):
            # forward
            data, target = data.to(self.device), target.to(self.device)
            fake_target = self.generator(data)

            ###########################
            # (1) train D network: maximize log(D(x,y)) + log(1 - D(x,G(x)))
            ###########################
            self.reset_grad()

            # train with fake
            fake_combined = torch.cat((data, fake_target), 1)
            fake_prediction = self.discriminator(fake_combined.detach())
            fake_d_loss = self.MSELoss(
                fake_prediction,
                torch.zeros(1,
                            1,
                            fake_prediction.size(2),
                            fake_prediction.size(3),
                            device=self.device))

            # train with real
            real_combined = torch.cat((data, target), 1)
            real_prediction = self.discriminator(real_combined)
            real_d_loss = self.MSELoss(
                real_prediction,
                torch.ones(1,
                           1,
                           real_prediction.size(2),
                           real_prediction.size(3),
                           device=self.device))

            # Combined loss
            loss_d = (fake_d_loss + real_d_loss) * 0.5
            loss_d.backward()
            self.d_optimizer.step()

            ##########################
            # (2) train G network: maximize log(D(x,G(x))) + L1(y,G(x))
            ##########################
            self.reset_grad()
            # First, G(A) should fake the discriminator
            fake_combined = torch.cat((data, fake_target), 1)
            fake_prediction = self.discriminator(fake_combined)
            g_loss_mse = self.MSELoss(
                fake_prediction,
                torch.ones(1,
                           1,
                           fake_prediction.size(2),
                           fake_prediction.size(3),
                           device=self.device))

            # Second, G(A) = B
            g_loss_l1 = self.L1loss(fake_target, target) * self.lamb
            loss_g = g_loss_mse + g_loss_l1
            loss_g.backward()
            self.g_optimizer.step()

            print("({}/{}): Loss_D: {:.4f} Loss_G: {:.4f}".format(
                i, len(self.training_data_loader), loss_d.item(),
                loss_g.item()))

    def test(self):
        self.mode_switch('eval')
        avg_psnr = 0
        with torch.no_grad():
            for (data, target) in self.testing_data_loader:
                data, target = data.to(self.device), target.to(self.device)

                prediction = self.generator(data)
                mse = self.MSELoss(prediction, target)
                psnr = 10 * log10(1 / mse.item())
                avg_psnr += psnr

        print("===> Avg. PSNR: {:.4f} dB".format(
            avg_psnr / len(self.testing_data_loader)))

    def run(self):
        self.build_model()
        self.build_dataset()
        for e in range(1, self.num_epochs + 1):
            print("===> Epoch {}/{}".format(e, self.num_epochs))
            self.train()
            self.checkpoint(e)
            self.test()
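
For context, a minimal driver for this Solver might look as follows; the argument names mirror the attributes read in __init__, while the default values are illustrative assumptions rather than settings from the original repository:

import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', default='facades')  # assumed dataset name
    parser.add_argument('--num_epochs', type=int, default=200)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--threads', type=int, default=4)
    parser.add_argument('--g_conv_dim', type=int, default=64)
    parser.add_argument('--d_conv_dim', type=int, default=64)
    parser.add_argument('--in_channel', type=int, default=3)
    parser.add_argument('--out_channel', type=int, default=3)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--beta_1', type=float, default=0.5)
    parser.add_argument('--lamb', type=float, default=10.0)
    Solver(parser.parse_args()).run()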
Exemple #13
0
class CycleBWE(object):
    def __init__(self, params):
        self.params = params
        self.tune_dir = "{}/{}-{}/{}".format(params.exp_id, params.src_lang,
                                             params.tgt_lang,
                                             params.norm_embeddings)
        self.tune_best_dir = "{}/best".format(self.tune_dir)
        self.tune_export_dir = "{}/export".format(self.tune_dir)
        if self.params.eval_file == 'wiki':
            self.eval_file = '../data/bilingual_dicts/{}-{}.5000-6500.txt'.format(
                self.params.src_lang, self.params.tgt_lang)
            self.eval_file2 = '../data/bilingual_dicts/{}-{}.5000-6500.txt'.format(
                self.params.tgt_lang, self.params.src_lang)
        elif self.params.eval_file == 'wacky':
            self.eval_file = '../data/bilingual_dicts/{}-{}.test.txt'.format(
                self.params.src_lang, self.params.tgt_lang)
            self.eval_file2 = '../data/bilingual_dicts/{}-{}.test.txt'.format(
                self.params.tgt_lang, self.params.src_lang)
        else:
            print('Invalid eval file!')
        # self.seed = random.randint(0, 1000)
        # self.seed = 41
        # self.initialize_exp(self.seed)

        self.X_AE = AE(params)
        self.Y_AE = AE(params)
        self.D_X = Discriminator(input_size=params.d_input_size,
                                 hidden_size=params.d_hidden_size,
                                 output_size=params.d_output_size)
        self.D_Y = Discriminator(input_size=params.d_input_size,
                                 hidden_size=params.d_hidden_size,
                                 output_size=params.d_output_size)

        self.nets = [self.X_AE, self.Y_AE, self.D_X, self.D_Y]
        self.loss_fn = torch.nn.BCELoss()
        self.loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def weights_init(self, m):  # orthogonal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.orthogonal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init2(self, m):  # Xavier normal initialization
        if isinstance(m, torch.nn.Linear):
            torch.nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.constant_(m.bias, 0.01)

    def weights_init3(self, m):  # identity-matrix initialization
        if isinstance(m, torch.nn.Linear):
            m.weight.data.copy_(
                torch.diag(torch.ones(self.params.g_input_size)))

    def init_state(self, state=1):
        if torch.cuda.is_available():
            # Move the network and the optimizer to the GPU
            for net in self.nets:
                net.cuda()
            self.loss_fn = self.loss_fn.cuda()
            self.loss_fn2 = self.loss_fn2.cuda()

        if self.params.init == 'eye':
            self.X_AE.apply(self.weights_init3)  # the G init scheme can be swapped here
            self.Y_AE.apply(self.weights_init3)

        elif self.params.init == 'orth':
            self.X_AE.apply(self.weights_init)
            self.Y_AE.apply(self.weights_init)
        else:
            print('Invalid init func!')

        #self.D_X.apply(self.weights_init2)
        #self.D_Y.apply(self.weights_init2)

    def orthogonalize(self, W):
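        # iterative orthogonalization update, W <- (1 + beta) * W - beta * W W^T W,
        # which nudges W back towards the orthogonal manifold after each G step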
        params = self.params
        W.copy_((1 + params.beta) * W -
                params.beta * W.mm(W.transpose(0, 1).mm(W)))

    def train(self, src_dico, tgt_dico, src_emb, tgt_emb, seed):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            print("Data path doesn't exists: %s" % params.data_dir)
        if not os.path.exists(self.tune_dir):
            os.makedirs(self.tune_dir)
        if not os.path.exists(self.tune_best_dir):
            os.makedirs(self.tune_best_dir)
        if not os.path.exists(self.tune_export_dir):
            os.makedirs(self.tune_export_dir)

        src_word2id = src_dico[1]
        tgt_word2id = tgt_dico[1]

        en = src_emb
        it = tgt_emb

        params = _get_eval_params(params)
        self.params = params
        evaluator = Evaluator(params, en, it, torch.cuda.is_available())

        # for seed_index in range(params.num_random_seeds):

        AE_optimizer = optim.SGD(filter(
            lambda p: p.requires_grad,
            list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),
                                 lr=params.g_learning_rate)
        # AE_optimizer = optim.SGD(G_params, lr=0.1, momentum=0.9)
        # AE_optimizer = optim.Adam(G_params, lr=params.g_learning_rate, betas=(0.9, 0.9))
        # AE_optimizer = optim.RMSprop(filter(lambda p: p.requires_grad, list(self.X_AE.parameters()) + list(self.Y_AE.parameters())),lr=params.g_learning_rate,alpha=0.9)
        D_optimizer = optim.SGD(list(self.D_X.parameters()) +
                                list(self.D_Y.parameters()),
                                lr=params.d_learning_rate)
        # D_optimizer = optim.Adam(D_params, lr=params.d_learning_rate, betas=(0.5, 0.9))
        # D_optimizer = optim.RMSprop(list(self.D_X.parameters()) + list(self.D_Y.parameters()), lr=params.d_learning_rate , alpha=0.9)

        # D_X=nn.DataParallel(D_X)
        # D_Y=nn.DataParallel(D_Y)
        # true_dict = get_true_dict(params.data_dir)
        D_A_acc_epochs = []
        D_B_acc_epochs = []
        D_A_loss_epochs = []
        D_B_loss_epochs = []
        G_AB_loss_epochs = []
        G_BA_loss_epochs = []
        G_AB_recon_epochs = []
        G_BA_recon_epochs = []
        L_Z_loss_epoches = []

        acc1_epochs = []
        acc2_epochs = []

        csls_epochs = []
        f_csls_epochs = []
        b_csls_epochs = []
        best_valid_metric = -100

        # logs for plotting later
        log_file = open(
            "log_src_tgt.txt",
            "w")  # Being overwritten in every loop, not really required
        log_file.write("epoch, dis_loss, dis_acc, g_loss\n")

        try:
            for epoch in range(self.params.num_epochs):
                D_A_losses = []
                D_B_losses = []
                G_AB_losses = []
                G_AB_recon = []
                G_BA_losses = []
                G_adv_losses = []
                G_BA_recon = []
                L_Z_losses = []
                d_losses = []
                g_losses = []
                hit_A = 0
                hit_B = 0
                total = 0
                start_time = timer()
                # lowest_loss = 1e5
                # label_D = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_())
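                # one-sided label smoothing: the real half of the batch is
                # labelled 1 - smoothing, the fake half gets smoothing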
                label_D = to_variable(
                    torch.FloatTensor(2 * params.mini_batch_size).zero_())
                label_D[:params.mini_batch_size] = 1 - params.smoothing
                label_D[params.mini_batch_size:] = params.smoothing

                label_G = to_variable(
                    torch.FloatTensor(params.mini_batch_size).zero_())
                label_G = label_G + 1 - params.smoothing

                for mini_batch in range(
                        0, params.iters_in_epoch // params.mini_batch_size):
                    for d_index in range(params.d_steps):
                        D_optimizer.zero_grad()  # Reset the gradients
                        self.D_X.train()
                        self.D_Y.train()

                        #print('D_X:', self.D_X.map1.weight.data)
                        #print('D_Y:', self.D_Y.map1.weight.data)

                        view_X, view_Y = self.get_batch_data_fast_new(en, it)
                        # Discriminator X
                        #print('View_Y',view_Y)
                        fake_X = self.Y_AE.encode(view_Y).detach()
                        #print('fakeX',fake_X)
                        input = torch.cat([view_X, fake_X], 0)

                        pred_A = self.D_X(input)
                        #print('Pred_A',pred_A)
                        D_A_loss = self.loss_fn(pred_A, label_D)
                        # print(view_Y)
                        # Discriminator Y
                        # print('View_X',view_X)
                        fake_Y = self.X_AE.encode(view_X).detach()
                        # print('fakeY:',fake_Y)

                        input = torch.cat([view_Y, fake_Y], 0)
                        pred_B = self.D_Y(input)
                        # print('Pred_B', pred_B)
                        D_B_loss = self.loss_fn(pred_B, label_D)

                        D_loss = (1.0) * D_A_loss + params.gate * D_B_loss

                        # compute/store gradients, but don't change params yet
                        D_loss.backward()
                        d_losses.append(to_numpy(D_loss.data))
                        D_A_losses.append(to_numpy(D_A_loss.data))
                        D_B_losses.append(to_numpy(D_B_loss.data))

                        discriminator_decision_A = to_numpy(pred_A.data)
                        hit_A += np.sum(
                            discriminator_decision_A[:params.mini_batch_size]
                            >= 0.5)
                        hit_A += np.sum(
                            discriminator_decision_A[params.mini_batch_size:] <
                            0.5)

                        discriminator_decision_B = to_numpy(pred_B.data)
                        hit_B += np.sum(
                            discriminator_decision_B[:params.mini_batch_size]
                            >= 0.5)
                        hit_B += np.sum(
                            discriminator_decision_B[params.mini_batch_size:] <
                            0.5)

                        # only optimizes D's parameters, using the gradients
                        # stored by backward()
                        D_optimizer.step()

                        # Clip weights
                        _clip(self.D_X, params.clip_value)
                        _clip(self.D_Y, params.clip_value)
                        # print('D_loss',d_losses)

                        sys.stdout.write(
                            "[%d/%d] :: Discriminator Loss: %.3f \r" %
                            (mini_batch,
                             params.iters_in_epoch // params.mini_batch_size,
                             float(np.mean(d_losses))))
                        sys.stdout.flush()

                    total += 2 * params.mini_batch_size * params.d_steps

                    for g_index in range(params.g_steps):
                        # 2. Train G on D's response (but DO NOT train D on these labels)
                        AE_optimizer.zero_grad()
                        self.D_X.eval()
                        self.D_Y.eval()
                        view_X, view_Y = self.get_batch_data_fast_new(en, it)

                        # Generator X_AE
                        ## adversarial loss
                        Y_fake = self.X_AE.encode(view_X)
                        # X_recon = self.X_AE.decode(X_Z)
                        # Y_fake = self.Y_AE.encode(X_Z)
                        pred_Y = self.D_Y(Y_fake)
                        L_adv_X = self.loss_fn(pred_Y, label_G)

                        X_Cycle = self.Y_AE.encode(Y_fake)
                        L_Cycle_X = 1.0 - torch.mean(
                            self.loss_fn2(view_X, X_Cycle))

                        # L_recon_X = 1.0 - torch.mean(self.loss_fn2(view_X, X_recon))
                        # L_G_AB = L_adv_X + params.recon_weight * L_recon_X

                        # Generator Y_AE
                        # adversarial loss
                        X_fake = self.Y_AE.encode(view_Y)
                        pred_X = self.D_X(X_fake)
                        L_adv_Y = self.loss_fn(pred_X, label_G)

                        ### Cycle Loss
                        Y_Cycle = self.X_AE.encode(X_fake)
                        L_Cycle_Y = 1.0 - torch.mean(
                            self.loss_fn2(view_Y, Y_Cycle))

                        # L_recon_Y = 1.0 - torch.mean(self.loss_fn2(view_Y, Y_recon))
                        # L_G_BA = L_adv_Y + params.recon_weight * L_recon_Y
                        # L_Z = 1.0 - torch.mean(self.loss_fn2(X_Z, Y_Z))

                        # G_loss = L_G_AB + L_G_BA + L_Z
                        G_loss = params.adv_weight * (params.gate * L_adv_X + 1.0 * L_adv_Y) + \
                                 params.cycle_weight * (L_Cycle_X + L_Cycle_Y)

                        G_loss.backward()

                        g_losses.append(to_numpy(G_loss.data))
                        G_AB_losses.append(to_numpy(L_adv_X.data))
                        G_BA_losses.append(to_numpy(L_adv_Y.data))
                        G_adv_losses.append(to_numpy(L_adv_Y.data))
                        G_AB_recon.append(to_numpy(L_Cycle_X.data))
                        G_BA_recon.append(to_numpy(L_Cycle_Y.data))

                        AE_optimizer.step()  # Only optimizes G's parameters
                        self.orthogonalize(self.X_AE.map1.weight.data)
                        self.orthogonalize(self.Y_AE.map1.weight.data)

                        sys.stdout.write(
                            "[%d/%d] ::                                     Generator Loss: %.3f \r"
                            % (mini_batch,
                               params.iters_in_epoch // params.mini_batch_size,
                               float(np.mean(g_losses))))
                        sys.stdout.flush()
                # per-epoch aggregates
                D_A_acc_epochs.append(hit_A / total)
                D_B_acc_epochs.append(hit_B / total)
                G_AB_loss_epochs.append(float(np.mean(G_AB_losses)))
                G_BA_loss_epochs.append(float(np.mean(G_BA_losses)))
                D_A_loss_epochs.append(float(np.mean(D_A_losses)))
                D_B_loss_epochs.append(float(np.mean(D_B_losses)))
                G_AB_recon_epochs.append(float(np.mean(G_AB_recon)))
                G_BA_recon_epochs.append(float(np.mean(G_BA_recon)))
                # L_Z_loss_epoches.append(float(np.mean(L_Z_losses)))

                print(
                    "Epoch {} : Discriminator Loss: {:.3f}, Discriminator Accuracy: {:.3f}, Generator Loss: {:.3f}, Time elapsed {:.2f} mins"
                    .format(epoch, float(np.mean(d_losses)),
                            0.5 * (hit_A + hit_B) / total,
                            float(np.mean(g_losses)),
                            (timer() - start_time) / 60))

                # lr decay
                # g_optim_state = AE_optimizer.state_dict()
                # old_lr = g_optim_state['param_groups'][0]['lr']
                # g_optim_state['param_groups'][0]['lr'] = max(old_lr * params.lr_decay, params.lr_min)
                # AE_optimizer.load_state_dict(g_optim_state)
                # print("Changing the learning rate: {} -> {}".format(old_lr, g_optim_state['param_groups'][0]['lr']))
                # d_optim_state = D_optimizer.state_dict()
                # d_optim_state['param_groups'][0]['lr'] = max(
                #     d_optim_state['param_groups'][0]['lr'] * params.lr_decay, params.lr_min)
                # D_optimizer.load_state_dict(d_optim_state)

                if (epoch + 1) % params.print_every == 0:
                    # No need for discriminator weights
                    # torch.save(d.state_dict(), 'discriminator_weights_en_es_{}.t7'.format(epoch))

                    # all_precisions = eval.get_all_precisions(G_AB(src_emb.weight).data)
                    Vec_xy = self.X_AE.encode(Variable(en))
                    Vec_xyx = self.Y_AE.encode(Vec_xy)
                    Vec_yx = self.Y_AE.encode(Variable(it))
                    Vec_yxy = self.X_AE.encode(Vec_yx)

                    mstart_time = timer()

                    # for method in ['csls_knn_10']:
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.src_lang,
                            src_word2id,
                            Vec_xy.data,
                            params.tgt_lang,
                            tgt_word2id,
                            it,
                            method=method,
                            dico_eval=self.eval_file,
                            device=params.cuda_device)
                        acc1 = results[0][1]
                        results = get_word_translation_accuracy(
                            params.tgt_lang,
                            tgt_word2id,
                            Vec_yx.data,
                            params.src_lang,
                            src_word2id,
                            en,
                            method=method,
                            dico_eval=self.eval_file2,
                            device=params.cuda_device)
                        acc2 = results[0][1]
                        print('{} takes {:.2f}s'.format(
                            method,
                            timer() - mstart_time))
                        print('Method:{} test_score:{:.4f}-{:.4f}'.format(
                            method, acc1, acc2))
                    '''
                    # for method in ['csls_knn_10']:
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.src_lang, src_word2id, Vec_xyx.data,
                            params.src_lang, src_word2id, en,
                            method=method,
                            dico_eval='/data/dictionaries/{}-{}.wacky.dict'.format(params.src_lang,params.src_lang),
                            device=params.cuda_device
                        )
                        acc11 = results[0][1]
                    # for method in ['csls_knn_10']:
                    for method in [params.eval_method]:
                        results = get_word_translation_accuracy(
                            params.tgt_lang, tgt_word2id, Vec_yxy.data,
                            params.tgt_lang, tgt_word2id, it,
                            method=method,
                            dico_eval='/data/dictionaries/{}-{}.wacky.dict'.format(params.tgt_lang,params.tgt_lang),
                            device=params.cuda_device
                        )
                        acc22 = results[0][1]
                    print('Valid:{} score:{:.4f}-{:.4f}'.format(method, acc11, acc22))
                    avg_valid = (acc11+acc22)/2.0
                    # valid_x = torch.mean(self.loss_fn2(en, Vec_xyx.data))
                    # valid_y = torch.mean(self.loss_fn2(it, Vec_yxy.data))
                    # avg_valid = (valid_x+valid_y)/2.0
                    '''
                    # csls = 0
                    f_csls = evaluator.dist_mean_cosine(Vec_xy.data, it)
                    b_csls = evaluator.dist_mean_cosine(Vec_yx.data, en)
                    csls = (f_csls + b_csls) / 2.0
                    # csls = evaluator.calc_unsupervised_criterion(X_Z)
                    if csls > best_valid_metric:
                        print("New csls value: {}".format(csls))
                        best_valid_metric = csls
                        fp = open(
                            self.tune_dir +
                            "/best/seed_{}_dico_{}_epoch_{}_acc_{:.3f}-{:.3f}.tmp"
                            .format(seed, params.dico_build, epoch, acc1,
                                    acc2), 'w')
                        fp.close()
                        torch.save(
                            self.X_AE.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_X.t7'.format(
                                seed, params.dico_build))
                        torch.save(
                            self.Y_AE.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_Y.t7'.format(
                                seed, params.dico_build))
                        torch.save(
                            self.D_X.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_Dx.t7'.format(
                                seed, params.dico_build))
                        torch.save(
                            self.D_Y.state_dict(), self.tune_dir +
                            '/best/seed_{}_dico_{}_best_Dy.t7'.format(
                                seed, params.dico_build))
                    # print(json.dumps(all_precisions))
                    # p_1 = all_precisions['validation']['adv']['without-ref']['nn'][1]
                    # p_1 = all_precisions['validation']['adv']['without-ref']['csls'][1]
                    # log_file.write(str(results) + "\n")
                    # print('Method: nn score:{:.4f}'.format(acc))
                    # Saving generator weights
                    # torch.save(X_AE.state_dict(), tune_dir+'/G_AB_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(seed,params.most_frequent_sampling_size,params.g_learning_rate,acc))
                    # torch.save(Y_AE.state_dict(), tune_dir+'/G_BA_seed_{}_mf_{}_lr_{}_p@1_{:.3f}.t7'.format(seed,params.most_frequent_sampling_size,params.g_learning_rate,acc))
                    fp = open(
                        self.tune_dir +
                        "/seed_{}_epoch_{}_acc_{:.3f}-{:.3f}_valid_{:.4f}.tmp".
                        format(seed, epoch, acc1, acc2, csls), 'w')
                    fp.close()
                    acc1_epochs.append(acc1)
                    acc2_epochs.append(acc2)
                    csls_epochs.append(csls)
                    f_csls_epochs.append(f_csls)
                    b_csls_epochs.append(b_csls)

            csls_fb, epoch_fb = max([
                (score, index) for index, score in enumerate(csls_epochs)
            ])
            fp = open(
                self.tune_dir +
                "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsfb".format(
                    seed, epoch_fb, acc1_epochs[epoch_fb],
                    acc2_epochs[epoch_fb], csls_fb), 'w')
            fp.close()
            csls_f, epoch_f = max([
                (score, index) for index, score in enumerate(f_csls_epochs)
            ])
            fp = open(
                self.tune_dir +
                "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsf".format(
                    seed, epoch_f, acc1_epochs[epoch_f], acc2_epochs[epoch_f],
                    csls_f), 'w')
            fp.close()
            csls_b, epoch_b = max([
                (score, index) for index, score in enumerate(b_csls_epochs)
            ])
            fp = open(
                self.tune_dir +
                "/best/seed_{}_epoch_{}_{:.3f}_{:.3f}_{:.3f}.cslsb".format(
                    seed, epoch_b, acc1_epochs[epoch_b], acc2_epochs[epoch_b],
                    csls_b), 'w')
            fp.close()
            '''

            # Save the plot for discriminator accuracy and generator loss
            fig = plt.figure()
            plt.plot(range(0, len(D_A_acc_epochs)), D_A_acc_epochs, color='b', label='D_A')
            plt.plot(range(0, len(D_B_acc_epochs)), D_B_acc_epochs, color='r', label='D_B')
            plt.ylabel('D_accuracy')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_acc.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(D_A_loss_epochs)), D_A_loss_epochs, color='b', label='D_A')
            plt.plot(range(0, len(D_B_loss_epochs)), D_B_loss_epochs, color='r', label='D_B')
            plt.ylabel('D_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_D_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_loss_epochs)), G_AB_loss_epochs, color='b', label='G_AB')
            plt.plot(range(0, len(G_BA_loss_epochs)), G_BA_loss_epochs, color='r', label='G_BA')
            plt.ylabel('G_losses')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(G_AB_recon_epochs)), G_AB_recon_epochs, color='b', label='G_AB')
            plt.plot(range(0, len(G_BA_recon_epochs)), G_BA_recon_epochs, color='r', label='G_BA')
            plt.ylabel('G_Cycle_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_G_Cycle.png'.format(seed))

            # fig = plt.figure()
            # plt.plot(range(0, len(L_Z_loss_epoches)), L_Z_loss_epoches, color='b', label='L_Z')
            # plt.ylabel('L_Z_loss')
            # plt.xlabel('epochs')
            # plt.legend()
            # fig.savefig(tune_dir + '/seed_{}_stage_{}_L_Z.png'.format(seed,stage))

            fig = plt.figure()
            plt.plot(range(0, len(acc1_epochs)), acc1_epochs, color='b', label='trans_acc1')
            plt.plot(range(0, len(acc2_epochs)), acc2_epochs, color='r', label='trans_acc2')
            plt.ylabel('trans_acc')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_trans_acc.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(csls_epochs)), csls_epochs, color='b', label='csls')
            plt.plot(range(0, len(f_csls_epochs)), f_csls_epochs, color='r', label='csls_f')
            plt.plot(range(0, len(b_csls_epochs)), b_csls_epochs, color='g', label='csls_b')
            plt.ylabel('csls')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_csls.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(g_losses)), g_losses, color='b', label='G_loss')
            plt.ylabel('g_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_g_loss.png'.format(seed))

            fig = plt.figure()
            plt.plot(range(0, len(d_losses)), d_losses, color='b', label='csls')
            plt.ylabel('D_loss')
            plt.xlabel('epochs')
            plt.legend()
            fig.savefig(self.tune_dir + '/seed_{}_d_loss.png'.format(seed))
            plt.close('all')
            '''

        except KeyboardInterrupt:
            print("Interrupted.. saving model !!!")
            torch.save(self.X_AE.state_dict(), 'g_model_interrupt.t7')
            torch.save(self.D_X.state_dict(), 'd_model_interrupt.t7')
            log_file.close()
            exit()

        log_file.close()
        return self.X_AE

    def get_batch_data_fast_new(self, emb_en, emb_it):
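        # draw a random mini-batch uniformly (with replacement) from the most
        # frequent words on each side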

        params = self.params
        random_en_indices = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        random_it_indices = torch.LongTensor(params.mini_batch_size).random_(
            params.most_frequent_sampling_size)
        #print(random_en_indices)
        #print(random_it_indices)
        en_batch = to_variable(emb_en)[random_en_indices.cuda()]
        it_batch = to_variable(emb_it)[random_it_indices.cuda()]
        return en_batch, it_batch

    def export(self,
               src_dico,
               tgt_dico,
               emb_en,
               emb_it,
               seed,
               export_emb=False):
        params = _get_eval_params(self.params)
        evaluator = Evaluator(params, emb_en, emb_it, torch.cuda.is_available())
        # Export adversarial dictionaries
        optim_X_AE = AE(params).cuda()
        optim_Y_AE = AE(params).cuda()
        print('Loading pre-trained models...')
        optim_X_AE.load_state_dict(
            torch.load(self.tune_dir +
                       '/best/seed_{}_dico_{}_best_X.t7'.format(
                           seed, params.dico_build)))
        optim_Y_AE.load_state_dict(
            torch.load(self.tune_dir +
                       '/best/seed_{}_dico_{}_best_Y.t7'.format(
                           seed, params.dico_build)))
        X_Z = optim_X_AE.encode(Variable(emb_en)).data
        Y_Z = optim_Y_AE.encode(Variable(emb_it)).data

        mstart_time = timer()
        for method in ['nn', 'csls_knn_10']:
            results = get_word_translation_accuracy(params.src_lang,
                                                    src_dico[1],
                                                    X_Z,
                                                    params.tgt_lang,
                                                    tgt_dico[1],
                                                    emb_it,
                                                    method=method,
                                                    dico_eval=self.eval_file,
                                                    device=params.cuda_device)
            acc1 = results[0][1]
            results = get_word_translation_accuracy(params.tgt_lang,
                                                    tgt_dico[1],
                                                    Y_Z,
                                                    params.src_lang,
                                                    src_dico[1],
                                                    emb_en,
                                                    method=method,
                                                    dico_eval=self.eval_file2,
                                                    device=params.cuda_device)
            acc2 = results[0][1]

            # csls = 0
            print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
            print('Method:{} score:{:.4f}-{:.4f}'.format(method, acc1, acc2))

        f_csls = evaluator.dist_mean_cosine(X_Z, emb_it)
        b_csls = evaluator.dist_mean_cosine(Y_Z, emb_en)
        csls = (f_csls + b_csls) / 2.0
        print("Seed:{},ACC:{:.4f}-{:.4f},CSLS_FB:{:.6f}".format(
            seed, acc1, acc2, csls))
        #'''
        print('Building dictionaries...')
        params.dico_build = "S2T&T2S"
        params.dico_method = "csls_knn_10"
        X_Z = X_Z / X_Z.norm(2, 1, keepdim=True).expand_as(X_Z)
        emb_it = emb_it / emb_it.norm(2, 1, keepdim=True).expand_as(emb_it)
        f_dico_induce = build_dictionary(X_Z, emb_it, params)
        f_dico_induce = f_dico_induce.cpu().numpy()
        Y_Z = Y_Z / Y_Z.norm(2, 1, keepdim=True).expand_as(Y_Z)
        emb_en = emb_en / emb_en.norm(2, 1, keepdim=True).expand_as(emb_en)
        b_dico_induce = build_dictionary(Y_Z, emb_en, params)
        b_dico_induce = b_dico_induce.cpu().numpy()

        f_dico_set = set([(a, b) for a, b in f_dico_induce])
        b_dico_set = set([(b, a) for a, b in b_dico_induce])

        intersect = list(f_dico_set & b_dico_set)
        union = list(f_dico_set | b_dico_set)

        with io.open(
                self.tune_dir +
                '/export/{}-{}.dict'.format(params.src_lang, params.tgt_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in f_dico_induce:
                f.write('{} {}\n'.format(src_dico[0][item[0]],
                                         tgt_dico[0][item[1]]))

        with io.open(
                self.tune_dir +
                '/export/{}-{}.dict'.format(params.tgt_lang, params.src_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in b_dico_induce:
                f.write('{} {}\n'.format(tgt_dico[0][item[0]],
                                         src_dico[0][item[1]]))

        with io.open(self.tune_dir + '/export/{}-{}.intersect'.format(
                params.src_lang, params.tgt_lang),
                     'w',
                     encoding='utf-8',
                     newline='\n') as f:
            for item in intersect:
                f.write('{} {}\n'.format(src_dico[0][item[0]],
                                         tgt_dico[0][item[1]]))

        with io.open(self.tune_dir + '/export/{}-{}.intersect'.format(
                params.tgt_lang, params.src_lang),
                     'w',
                     encoding='utf-8',
                     newline='\n') as f:
            for item in intersect:
                f.write('{} {}\n'.format(tgt_dico[0][item[1]],
                                         src_dico[0][item[0]]))

        with io.open(
                self.tune_dir +
                '/export/{}-{}.union'.format(params.src_lang, params.tgt_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in union:
                f.write('{} {}\n'.format(src_dico[0][item[0]],
                                         tgt_dico[0][item[1]]))

        with io.open(
                self.tune_dir +
                '/export/{}-{}.union'.format(params.tgt_lang, params.src_lang),
                'w',
                encoding='utf-8',
                newline='\n') as f:
            for item in union:
                f.write('{} {}\n'.format(tgt_dico[0][item[1]],
                                         src_dico[0][item[0]]))

        if export_emb:
            print('Exporting {}-{}.{}'.format(params.src_lang, params.tgt_lang,
                                              params.src_lang))
            loader.export_embeddings(
                src_dico[0],
                X_Z,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.src_lang, params.tgt_lang, params.src_lang),
                eformat='txt')
            print('Exporting {}-{}.{}'.format(params.src_lang, params.tgt_lang,
                                              params.tgt_lang))
            loader.export_embeddings(
                tgt_dico[0],
                emb_it,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.src_lang, params.tgt_lang, params.tgt_lang),
                eformat='txt')
            print('Exporting {}-{}.{}'.format(params.tgt_lang, params.src_lang,
                                              params.tgt_lang))
            loader.export_embeddings(
                tgt_dico[0],
                Y_Z,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.tgt_lang, params.src_lang, params.tgt_lang),
                eformat='txt')
            print('Exporting {}-{}.{}'.format(params.tgt_lang, params.src_lang,
                                              params.src_lang))
            loader.export_embeddings(
                src_dico[0],
                emb_en,
                path=self.tune_dir + '/export/{}-{}.{}'.format(
                    params.tgt_lang, params.src_lang, params.src_lang),
                eformat='txt')
Exemple #14
0
    model = Discriminator(args.length, len(species)).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.rate)

    # raise an error if receptive field is smaller than sampling length
    if args.length < receptive_field(model):
        raise Exception("Input sequences must be longer than {} bp.".format(
            receptive_field(model)))

    for epoch in range(args.epoch):
        train(model, device, loader, optimizer, epoch + 1)
        print("")

    # calculate style matrices
    if args.verbose > 1: print("Extracting style matrices...")

    model.eval()
    style_matrices = []

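    # accumulate one style matrix per FASTA record from the chosen layer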
    for record in SeqIO.parse(args.contig, "fasta"):
        tensor = to_tensor(str(record.seq))
        style_matrices += model.get_style(tensor.float().to(device),
                                          args.layer)

    style_matrices = torch.cat(style_matrices, dim=0)

    torch.save(style_matrices, args.output)

    if args.verbose > 1:
        print("Genome style matrix is successfully written to {}.".format(
            args.output))
Exemple #15
0
def train():
    args=Config
    generator = Generator(args)
    discriminator = Discriminator(args)
    feat_extractor = get_feat_extractor()
    dataset = SRDataset(dataset_path=args['train_set_path'],hr_size=args['hr_size'], scale_factor=args['scale'])
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=Config['batch_size'],
                    shuffle=True)
    test_dataset = SRDataset(args['test_set_path'],args['hr_size'], args['scale'])
    test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4,
                    shuffle=True)

    if Config['optimizer']=='Adam':
        gen_optimizer = torch.optim.Adam(generator.parameters(), lr = Config['lr'])
        disc_optimizer = torch.optim.Adam(discriminator.parameters(), lr = Config['lr'])
    else:
        gen_optimizer = torch.optim.SGD(generator.parameters(), lr = Config['lr'])
        disc_optimizer = torch.optim.SGD(discriminator.parameters(), lr = Config['lr'])

    if Config['tensorboard_log']:
        writer = SummaryWriter(Config['checkpoint_path'])

    for epoch in tqdm(range(Config['epochs'])):
        generator.train()
        discriminator.train()
        for lr, hr in data_loader:
            # real targets are ones, generated targets are zeros
            valid = torch.ones((lr.shape[0], 1), requires_grad=False)
            fake = torch.zeros((lr.shape[0], 1), requires_grad=False)
            sr = generator(lr)

            # discriminator step: detach sr so no generator grads are built here
            disc_optimizer.zero_grad()
            d_real = discriminator(hr)
            d_fake = discriminator(sr.detach())
            valid_loss = nn.BCELoss()(d_real, valid)
            fake_loss = nn.BCELoss()(d_fake, fake)
            d_loss = valid_loss + fake_loss
            d_loss.backward()
            disc_optimizer.step()

            # generator step: content + adversarial + pixel-wise MSE terms
            gen_optimizer.zero_grad()
            d_fake = discriminator(sr)
            c_loss = content_loss(args, feat_extractor, hr, sr)
            adv_loss = 1e-3 * nn.BCELoss()(d_fake, valid)
            mse_loss = nn.MSELoss()(sr, hr)
            perceptual_loss = c_loss + adv_loss + mse_loss
            perceptual_loss.backward()
            gen_optimizer.step()
        generator.eval()
        discriminator.eval()
        test_lr, test_hr = next(iter(test_data_loader))
        with torch.set_grad_enabled(False):
            test_sr = generator(test_lr)
            for i in range(test_sr.shape[0]):
                img_sr = test_sr[i]
                img_hr = test_hr[i]
                img_lr = test_lr[i]
                save_image(img_sr, 'img_sr_%d.png'%i)
                save_image(img_hr, 'img_hr_%d.png'%i)
                save_image(img_lr, 'img_lr_%d.png'%i)

        print(f'Epoch {epoch}: Perceptual Loss:{perceptual_loss:.4f}, Disc Loss:{d_loss:.4f}')
    torch.save({'generator':generator,
                'discriminator':discriminator},
                os.path.join(Config['checkpoint_path'],'model.pth'))
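
Since train() reads every setting from the module-level Config, a run reduces to the sketch below; it assumes Config is fully populated. Note that torch.save above serializes the full nn.Module objects, whereas saving state_dict()s is the more portable convention.

if __name__ == '__main__':
    train()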
Exemple #16
0
            # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #

            G_loss = criterion(fake_predict, torch.ones_like(fake_predict))
            # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
            gen.zero_grad()
            G_loss.backward()
            gen_optim.step()

        # ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #

        if batch_idx == 0:
            print(f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
                            Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}")

            with torch.no_grad():
                disc.eval()
                gen.eval()
                fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
                data = real.reshape(-1, CHANNELS, H, W)
                if BATCH_SIZE > 32:
                    fake = fake[:32]
                    data = data[:32]
                img_grid_fake = torchvision.utils.make_grid(fake,
                                                            normalize=True)
                img_grid_real = torchvision.utils.make_grid(data,
                                                            normalize=True)

                writer_fake.add_image("Mnist Fake Images",
                                      img_grid_fake,
                                      global_step=step)
                writer_real.add_image("Mnist Real Images",
Exemple #17
0
class StarGAN(nn.Module):
    def __init__(self, config, train_loader, test_loader):
        super(StarGAN, self).__init__()

        self.config = config
        self.train_loader = train_loader
        self.test_loader = test_loader

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.test_source, self.test_domain, _ = next(iter(self.test_loader))
        self.test_source = self.test_source.to(self.device)
        self.test_domain = self.test_domain.view(-1, 1, 1).to(self.device)
        self.test_batch_size, _, self.height, self.width = self.test_source.size()
        self.save_img_cnt = 0
        self.loss = {}
        self.items = {}

        self.iter_size = len(self.train_loader)
        self.epoch_size = config['max_iter'] // self.iter_size + 1

        lr = config['lr']
        lr_F = config['lr_F']
        beta1 = config['beta1']
        beta2 = config['beta2']
        init = config['init']
        # weight_decay = config['weight_decay']

        self.batch_size = config['batch_size']
        self.gan_type = config['gan_type']
        self.max_iter = config['max_iter']
        self.img_size = config['crop_size']

        self.path_sample = os.path.join('./results/', config['save_name'],
                                        "samples")
        self.path_model = os.path.join('./results/', config['save_name'],
                                       "models")

        self.w_style = config['w_style']
        self.w_ds = config['w_ds']
        self.w_cyc = config['w_cyc']
        self.w_regul = config['w_regul']

        self.num_domain = len(train_loader.dataset.domains)
        self.dim_style = config['dim_style']
        self.dim_latent = config['mapping_network']['dim_latent']

        self.generator = Generator(config['gen'])  # 29072960
        # self.generator = DummyModel(config['gen'])  # 29072960
        self.style_encoder = StyleEncoder(config['style_encoder'],
                                          self.num_domain, self.img_size)
        self.mapping_network = MappingNetwork(config['mapping_network'],
                                              self.num_domain, self.dim_style)
        self.discriminator = Discriminator(config['dis'], self.num_domain,
                                           self.img_size)

        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(),
                                            lr, (beta1, beta2))
        params_g = list(self.generator.parameters()) + list(
            self.style_encoder.parameters())
        self.optimizer_g = torch.optim.Adam(params_g, lr, (beta1, beta2))
        self.optimizer_g.add_param_group({
            'params': self.mapping_network.parameters(),
            'lr': lr_F,
            'betas': (beta1, beta2),
        })

        # self.scheduler_g = get_scheduler(self.optimizer_g, config)
        # self.scheduler_d = get_scheduler(self.optimizer_d, config)

        self.apply(weights_init(init))

        self.criterion_l1 = nn.L1Loss()
        self.criterion_l2 = nn.MSELoss()
        self.criterion_bce = nn.BCEWithLogitsLoss()

        self.to(self.device)

    # def update_scheduler(self):
    #     if self.current_epoch >= 10 and self.scheduler_d and self.scheduler_g:
    #         self.scheduler_d.step()
    #         self.scheduler_g.step()

    def calc_adversarial_loss(self, logit, is_real):
        if self.gan_type == 'bce':
            target_fn = torch.ones_like if is_real else torch.zeros_like
            loss = self.criterion_bce(logit, target_fn(logit))

        elif self.gan_type == 'lsgan':
            target_fn = torch.ones_like if is_real else torch.zeros_like
            loss = self.criterion_l2(logit, target_fn(logit))

        elif self.gan_type == 'wgan':
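            # WGAN critic objective: push real scores up, fake scores down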
            if is_real:
                loss = -torch.mean(logit)
            else:
                loss = torch.mean(logit)
        else:
            raise NotImplementedError("Unsupported gan type: {}".format(
                self.gan_type))

        return loss

    def calc_r1(self, real_images, logit_real):
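        # R1 regularizer: penalizes the squared gradient norm of the real
        # logits with respect to the input images (Mescheder et al., 2018)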
        grad_real = autograd.grad(outputs=logit_real.sum(),
                                  inputs=real_images,
                                  create_graph=True)[0]
        grad_penalty = (grad_real.view(grad_real.size(0),
                                       -1).norm(2, dim=1)**2).mean()
        grad_penalty = 0.5 * grad_penalty
        return grad_penalty

    def calc_gp(self, real_images, fake_images):  # TODO : not wired up yet
        raise NotImplementedError(
            "WGAN-GP path is disabled: the discriminator call below predates "
            "the domain-conditional signature used elsewhere in this trainer")
        alpha = torch.rand(real_images.size(0), 1, 1, 1).to(self.device)
        interpolated = (alpha * real_images +
                        ((1 - alpha) * fake_images)).requires_grad_(True)
        prob_interpolated, _ = self.discriminator(interpolated)

        grad_outputs = torch.ones(prob_interpolated.size()).to(self.device)
        gradients = torch.autograd.grad(outputs=prob_interpolated,
                                        inputs=interpolated,
                                        grad_outputs=grad_outputs,
                                        create_graph=True,
                                        retain_graph=True)[0]

        gradients = gradients.reshape(gradients.size(0), -1)
        gradient_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
        return gradient_penalty

    def generate_random_noise(self):
        random_noise = torch.randn(1, self.dim_latent).to(self.device)
        random_domain = torch.randint(self.num_domain,
                                      (self.batch_size, 1, 1)).to(self.device)
        return random_noise, random_domain

    def eval_mode_all(self):
        self.discriminator.eval()
        self.generator.eval()
        self.style_encoder.eval()
        self.mapping_network.eval()

    def update_d(self, real, real_domain, random_noise, random_domain):
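        # One discriminator step: adversarial terms on real and detached fake
        # images, plus R1 (bce) or gradient-penalty (wgan) regularization.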
        reset_gradients([self.optimizer_g, self.optimizer_d])
        real.requires_grad = True

        style_mapped = self.mapping_network(random_noise, random_domain)
        fake = self.generator(real, style_mapped)

        # Adv
        logit_real = self.discriminator(real, real_domain)
        logit_fake = self.discriminator(fake.detach(), random_domain)

        adv_d_real = self.calc_adversarial_loss(logit_real,
                                                is_real=True)  # .contiguous()
        adv_d_fake = self.calc_adversarial_loss(logit_fake,
                                                is_real=False)  # .contiguous()

        if self.gan_type == 'bce':
            regul = self.calc_r1(real, logit_real) * self.w_regul
        elif self.gan_type == 'wgan':
            regul = self.calc_gp(real, fake) * self.w_regul
        else:  # lsgan: no regularizer; keep a zero tensor so regul.item() below works
            regul = torch.zeros(1, device=self.device)

        self.adv_d_fake = adv_d_fake
        self.adv_d_real = adv_d_real
        loss_d = adv_d_fake + adv_d_real + regul
        loss_d.backward()
        self.optimizer_d.step()

        self.loss['adv_d_fake'] = adv_d_fake.item()
        self.loss['adv_d_real'] = adv_d_real.item()
        self.loss['regul'] = regul.item()

        self.items["logit_real"] = logit_real
        self.items["logit_fake_d"] = logit_fake

    def update_g(self, real, real_domain, random_noise, random_domain):
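        # One generator step, StarGAN v2-style: adversarial loss, style
        # reconstruction, style diversification (maximized, hence the negated
        # L1 below), and cycle consistency back to the input image.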
        reset_gradients([self.optimizer_g, self.optimizer_d])

        style_fake = self.mapping_network(random_noise, random_domain)
        style_real = self.style_encoder(real, real_domain)
        fake = self.generator(real, style_fake)
        style_recon = self.style_encoder(fake, random_domain)
        image_recon = self.generator(fake, style_real)

        # Adversarial
        logit_fake = self.discriminator(fake, random_domain)
        adv_g = self.calc_adversarial_loss(logit_fake, is_real=True)

        # Style recon
        style_recon_loss = self.criterion_l1(style_fake,
                                             style_recon) * self.w_style

        # Style diversification
        random_noise1 = torch.randn(1, self.dim_latent).to(self.device)
        random_noise2 = torch.randn(1, self.dim_latent).to(self.device)
        random_domain1 = torch.randint(self.num_domain,
                                       (self.batch_size, 1, 1)).to(self.device)

        s1 = self.mapping_network(random_noise1, random_domain1)
        s2 = self.mapping_network(random_noise2, random_domain1)
        fake1 = self.generator(real, s1)
        fake2 = self.generator(real, s2)

        ds_loss = -self.criterion_l1(fake1, fake2) * self.w_ds

        # Cycle consistency
        cyc_loss = self.criterion_l1(real, image_recon) * self.w_cyc

        loss_g = adv_g + cyc_loss + style_recon_loss + ds_loss
        loss_g.backward()
        self.optimizer_g.step()

        self.loss['adv_g'] = adv_g.item()
        self.loss['style_recon_loss'] = style_recon_loss.item()
        self.loss['ds_loss'] = ds_loss.item()
        self.loss['cyc_loss'] = cyc_loss.item()

        self.items["real"] = real
        self.items["real_domain"] = real_domain
        self.items["random_noise"] = random_noise
        self.items["random_domain"] = random_domain
        self.items["random_noise1"] = random_noise1
        self.items["random_noise2"] = random_noise2
        self.items["random_domain1"] = random_domain1
        self.items["logit_fake"] = logit_fake
        self.items["style_fake"] = style_fake
        self.items["style_real"] = style_real
        self.items["fake"] = fake
        self.items["recon"] = image_recon
        self.items["style_recon"] = style_recon

    def train_starGAN(self, init_epoch):
        d_step, g_step = self.config['d_step'], self.config['g_step']
        log_iter = self.config['log_iter']
        image_display_iter = self.config['image_display_iter']
        image_save_iter = self.config['image_save_iter']

        for epoch in range(init_epoch, self.epoch_size):
            self.current_epoch = epoch
            self.save_img_cnt = 0
            for iters, (real, real_domain, _) in enumerate(self.train_loader):
                # self.update_scheduler()

                # real, real_domain = real.to(self.device), real_domain.view(-1, 1, 1).to(self.device)
                real, real_domain = real.to(self.device), real_domain.to(
                    self.device)
                random_noise, random_domain = self.generate_random_noise()

                if not iters % d_step:
                    self.update_d(real, real_domain, random_noise,
                                  random_domain)

                if not iters % g_step:
                    self.update_g(real, real_domain, random_noise,
                                  random_domain)

                if self.device.type == 'cuda':
                    torch.cuda.synchronize()

                if not (iters + 1) % log_iter:
                    self.print_log(epoch, iters)

                    if not (iters + 1) % image_display_iter:
                        show_batch_torch(torch.cat([
                            real, self.items['fake'].clamp(-1, 1),
                            self.items['recon'].clamp(-1, 1)
                        ]),
                                         n_rows=3,
                                         n_cols=-1)

                        if not (iters + 1) % image_save_iter:
                            self.test_sample = self.generate_test_samples(
                                save=True)
                            clear_jupyter_console()

                # TODO : arbitrary
                if epoch >= 10 and not (iters + 1) % 1000:
                    print("w_ds decayed:", self.w_ds, " -> ", self.w_ds * 0.9)
                    self.w_ds *= 0.9

            self.save_models(epoch)

    def print_log(self, epoch, iters):
        adv_d_real = self.loss['adv_d_real']
        adv_d_fake = self.loss['adv_d_fake']
        regul = self.loss['regul']
        adv_g = self.loss['adv_g']
        style_recon_loss = self.loss['style_recon_loss']
        ds_loss = self.loss['ds_loss']
        cyc_loss = self.loss['cyc_loss']

        print(
            "[Epoch {}/{}, iters: {}/{}] " \
            "- Adv: {:5.4} {:5.4} / {:5.4}, Style recon: {:5.4}, DS: {:5.4}, Cyc : {:5.4}, Regul : {:5.4}".format(
                epoch, self.epoch_size, iters + 1, self.iter_size,
                adv_d_real, adv_d_fake, adv_g, style_recon_loss, ds_loss, cyc_loss, regul
            )
        )

    def save_models(self, epoch):
        os.makedirs(self.path_model, exist_ok=True)

        state = {
            'generator': self.generator.state_dict(),
            'discriminator': self.discriminator.state_dict(),
            'optimizer_d': self.optimizer_d.state_dict(),
            'optimizer_g': self.optimizer_g.state_dict(),
            # 'scheduler_d': self.scheduler_d.state_dict(),  # TODO
            # 'scheduler_g': self.scheduler_g.state_dict(),
            'w_ds': self.w_ds,
            'current_epoch': epoch,
        }

        save_name = os.path.join(self.path_model, "epoch_{:02}".format(epoch))
        torch.save(state, save_name)

    def load_models(self, epoch=False):
        if epoch is False:  # epoch 0 is a valid checkpoint, so test identity, not truthiness
            last_model_path = sorted(
                glob.glob(os.path.join(self.path_model, '*')))[-1]
            epoch = int(last_model_path.split('/')[-1].split('_')[1][:2])

        save_name = os.path.join(self.path_model, "epoch_{:02}".format(epoch))
        checkpoint = torch.load(save_name)

        # weight
        self.discriminator.load_state_dict(checkpoint['discriminator'])
        self.generator.load_state_dict(checkpoint['generator'])
        self.optimizer_d.load_state_dict(checkpoint['optimizer_d'])
        self.optimizer_g.load_state_dict(checkpoint['optimizer_g'])
        # self.scheduler_d.load_state_dict(checkpoint['scheduler_d'])
        # self.scheduler_g.load_state_dict(checkpoint['scheduler_g'])
        self.w_ds = checkpoint['w_ds']
        self.current_epoch = checkpoint['current_epoch']
        return epoch

    def resume_train(self, restart_epoch=False):
        restart_epoch = self.load_models(restart_epoch)
        print("Resume Training - Epoch: ", restart_epoch)
        self.train_starGAN(restart_epoch + 1)

    def generate_test_samples(self, save):
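        # Reference-guided grid: the left column holds the style references,
        # the top row the fixed test sources, and each cell the corresponding
        # source re-rendered with that reference's encoded style.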
        os.makedirs(self.path_sample, exist_ok=True)

        with torch.no_grad():
            reference, reference_domain, _ = next(iter(self.test_loader))
            reference, reference_domain = reference.to(
                self.device), reference_domain.to(self.device)

            style_reference = self.style_encoder(reference, reference_domain)
            style_reference = style_reference.repeat(1, reference.size(0),
                                                     1).view(
                                                         -1, 1, self.dim_style)
            source = self.test_source.repeat(reference.size(0), 1, 1,
                                             1).view(-1, 3, self.height,
                                                     self.width)
            generated = self.generator(source, style_reference).clamp(-1, 1)

            right_concat, _, _ = reshape_batch_torch(
                torch.cat([self.test_source, generated]),
                n_cols=self.test_batch_size,
                n_rows=-1)

            left_concat = torch.cat(
                [torch.zeros_like(reference[:1]), reference])
            left_concat, _, _ = reshape_batch_torch(left_concat,
                                                    n_cols=1,
                                                    n_rows=-1)

            save_image = preprocess(
                np.concatenate([left_concat, right_concat], axis=1))

            if save:
                save_name = os.path.join(
                    self.path_sample,
                    "{:02}_{:02}.jpg".format(self.current_epoch,
                                             self.save_img_cnt))
                self.save_img_cnt += 1
                plt.imsave(save_name, save_image)
                print("Test samples Saved:" + save_name)
        return save_image
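
# Hedged sketch (assumption, not part of the original source): the
# reset_gradients helper used in update_d/update_g above presumably just
# zeroes every optimizer's gradient buffers before the backward pass.
def reset_gradients(optimizers):
    for optimizer in optimizers:
        optimizer.zero_grad()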
Exemple #18
0
class SAGAN:
    def __init__(self, args):
        self.args = args

        self.gen_model = Generator(args.channels, args.image_size, args.latent_dim, args.ngf)
        self.dis_model = Discriminator(args.channels, args.image_size, args.ndf)
        self.gen_opt = torch.optim.Adam(self.gen_model.parameters(), lr = args.gen_lr, betas = (args.beta1, args.beta2), weight_decay = args.weight_decay)
        self.dis_opt = torch.optim.Adam(self.dis_model.parameters(), lr = args.dis_lr, betas = (args.beta1, args.beta2), weight_decay = args.weight_decay)
        self.anime_dataset = AnimeDataset(args.base_image_path)
        self.train_loader = DataLoader(self.anime_dataset, batch_size = args.batch_size, shuffle = True, drop_last = False)
    
    def train_one_epoch(self, epoch):
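        # Alternating hinge-GAN step: update D on real vs. detached fakes,
        # then resample latents and update G on fresh fakes.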
        self.gen_model.train()
        self.dis_model.train()
        
        print('[INFO] Epoch:', epoch)

        pbar = tqdm(self.train_loader, total = len(self.train_loader))
        acc_d_loss, acc_g_loss = 0, 0
        for images in pbar:
            images = images.to(self.args.device).float()
            latents = torch.randn(self.args.batch_size, self.args.latent_dim, 1, 1, device = self.args.device).float()
            fake_images = self.gen_model(latents)
            dis_real = self.dis_model(images)
            dis_fake = self.dis_model(fake_images.detach())

            self.dis_opt.zero_grad()
            dis_loss = dis_hinge_loss(dis_fake, dis_real)
            dis_loss.backward()
            self.dis_opt.step()

            latents = torch.randn(self.args.batch_size, self.args.latent_dim, 1, 1, device = self.args.device).float()
            fake_images = self.gen_model(latents)
            dis_fake = self.dis_model(fake_images)

            self.gen_opt.zero_grad()
            gen_loss = gen_hinge_loss(dis_fake)
            gen_loss.backward()
            self.gen_opt.step()

            # exponentially smoothed running losses
            acc_d_loss = acc_d_loss * self.args.loss_smooth + dis_loss.detach().cpu().item() * (1 - self.args.loss_smooth)
            acc_g_loss = acc_g_loss * self.args.loss_smooth + gen_loss.detach().cpu().item() * (1 - self.args.loss_smooth)

    def visualize(self, num_samples = 20):
        self.gen_model.eval()
        self.dis_model.eval()

        latents = torch.randn(num_samples, self.args.latent_dim, 1, 1, device = self.args.device)
        fake_images = self.gen_model(latents).detach().cpu().numpy()
        fake_images = fake_images * .5 + .5
        fake_images = np.transpose(fake_images, (0, 2, 3, 1))
        plt.figure(figsize = (10, 10))
        for i in range(num_samples):
            plt.subplot(4, num_samples // 4, i + 1)
            plt.imshow(fake_images[i])
        plt.savefig(str(time()) + '.jpg')

    def save_checkpoints(self, epoch):
        if not os.path.exists(self.args.checkpoints_path):
            os.mkdir(self.args.checkpoints_path)
        torch.save({
            'gen_model': self.gen_model.state_dict(),
            'dis_model': self.dis_model.state_dict(),
            'gen_opt': self.gen_opt.state_dict(),
            'dis_opt': self.dis_opt.state_dict()
        }, f'{self.args.checkpoints_path}/epoch_{epoch}_{time()}.tar')

    def train(self):
        for epoch in range(self.args.epochs):
            self.train_one_epoch(epoch + 1)
            self.visualize()
            if epoch == 0 or (epoch + 1) % self.args.checkpoint_step == 0:
                self.save_checkpoints(epoch + 1)
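
# Hedged sketch (assumption, not part of the original source): the hinge
# losses referenced above are taken to follow the standard SAGAN /
# geometric-GAN formulation.
def dis_hinge_loss(dis_fake, dis_real):
    # D pushes real scores above +1 and fake scores below -1
    return (torch.mean(torch.relu(1.0 - dis_real)) +
            torch.mean(torch.relu(1.0 + dis_fake)))

def gen_hinge_loss(dis_fake):
    # G maximizes the critic score on generated samples
    return -torch.mean(dis_fake)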
Exemple #19
0
def main():
    
    parser = argparse.ArgumentParser() 
    parser.add_argument('--input_dir', help = 'Directory containing xxx_i_s and xxx_i_t with same prefix',
                        default = cfg.example_data_dir)
    parser.add_argument('--save_dir', help = 'Directory to save result', default = cfg.predict_result_dir)
    parser.add_argument('--checkpoint', help = 'ckpt', default = cfg.ckpt_path)
    args = parser.parse_args()

    assert args.input_dir is not None
    assert args.save_dir is not None
    assert args.checkpoint is not None

    print_log('model compiling start.', content_color = PrintColor['yellow'])

    G = Generator(in_channels = 3).to(device)
    D1 = Discriminator(in_channels = 6).to(device)
    D2 = Discriminator(in_channels = 6).to(device)  
    vgg_features = Vgg19().to(device)   
      
    G_solver = torch.optim.Adam(G.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D1_solver = torch.optim.Adam(D1.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))
    D2_solver = torch.optim.Adam(D2.parameters(), lr=cfg.learning_rate, betas = (cfg.beta1, cfg.beta2))

    checkpoint = torch.load(args.checkpoint)
    G.load_state_dict(checkpoint['generator'])
    D1.load_state_dict(checkpoint['discriminator1'])
    D2.load_state_dict(checkpoint['discriminator2'])
    G_solver.load_state_dict(checkpoint['g_optimizer'])
    D1_solver.load_state_dict(checkpoint['d1_optimizer'])
    D2_solver.load_state_dict(checkpoint['d2_optimizer'])

    trfms = To_tensor()
    example_data = example_dataset(data_dir= args.input_dir, transform = trfms)
    example_loader = DataLoader(dataset = example_data, batch_size = 1, shuffle = False)
    example_iter = iter(example_loader)

    print_log('Model compiled.', content_color = PrintColor['yellow'])

    print_log('Predicting', content_color = PrintColor['yellow'])

    G.eval()
    D1.eval()
    D2.eval()

    with torch.no_grad():

      for step in tqdm(range(len(example_data))):

        try:
          inp = next(example_iter)
        except StopIteration:
          example_iter = iter(example_loader)
          inp = next(example_iter)

        i_t = inp[0].to(device)
        i_s = inp[1].to(device)
        name = str(inp[2][0])

        o_sk, o_t, o_b, o_f = G(i_t, i_s, (i_t.shape[2], i_t.shape[3]))

        o_sk = o_sk.squeeze(0).detach().to('cpu')
        o_t = o_t.squeeze(0).detach().to('cpu')
        o_b = o_b.squeeze(0).detach().to('cpu')
        o_f = o_f.squeeze(0).detach().to('cpu')

        if not os.path.exists(args.save_dir):
            os.makedirs(args.save_dir)

        o_sk = F.to_pil_image(o_sk)
        o_t = F.to_pil_image((o_t + 1)/2)
        o_b = F.to_pil_image((o_b + 1)/2)
        o_f = F.to_pil_image((o_f + 1)/2)
                        
        o_f.save(os.path.join(args.save_dir, name + 'o_f.png'))
Exemple #20
0
                ) * hr_img.size(0)

            torch.save(generator_net.state_dict(),
                       weights_dir + 'G_epoch_%d.pth' % (epoch))
            generator_losses.append(
                (epoch, generator_running_loss / len(train_set)))
            discriminator_losses.append(
                (epoch, discriminator_running_loss / len(train_set)))

            if epoch % 50 == 0:

                with torch.no_grad():
                    cur_epoch_dir = imgout_dir + str(epoch) + '/'
                    os.makedirs(cur_epoch_dir, exist_ok=True)
                    generator_net.eval()
                    discriminator_net.eval()
                    valid_bar = tqdm(validloader)
                    img_count = 0
                    psnr_avg = 0.0
                    psnr = 0.0
                    for hr_img, lr_img in valid_bar:
                        valid_bar.set_description('Img: %i   PSNR: %f' %
                                                  (img_count, psnr))
                        if torch.cuda.is_available():
                            lr_img = lr_img.cuda()
                            hr_img = hr_img.cuda()
                        sr_tensor = generator_net(lr_img)
                        mse = torch.mean((hr_img - sr_tensor)**2)
                        psnr = 10 * (torch.log10(1 / mse) + np.log10(4))
                        psnr_avg += psnr
                        img_count += 1
Exemple #21
0
class ModelBuilder(object):
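    # NOTE: this example is written against pre-0.4 PyTorch idioms
    # (Variable(..., volatile=True), clip_grad_norm without the trailing
    # underscore, scalar access via tensor[0]); on modern PyTorch these
    # would be torch.no_grad(), clip_grad_norm_, and tensor.item().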
    def __init__(self, use_cuda):
        self.cuda = use_cuda
        self._pre_data()
        self._build_model()
        self.i_mb = 0

    def _pre_data(self):
        print('pre data...')
        self.data = Data(self.cuda)

    def _build_model(self):
        print('building model...')
        we = torch.load('./data/processed/we.pkl')
        self.i_encoder = CNN_Args_encoder(we)
        self.a_encoder = CNN_Args_encoder(we, need_kmaxavg=True)
        self.classifier = Classifier()
        self.discriminator = Discriminator()
        if self.cuda:
            self.i_encoder.cuda()
            self.a_encoder.cuda()
            self.classifier.cuda()
            self.discriminator.cuda()
        self.criterion_c = torch.nn.CrossEntropyLoss()
        self.criterion_d = torch.nn.BCELoss()
        para_filter = lambda model: filter(lambda p: p.requires_grad,
                                           model.parameters())
        self.i_optimizer = torch.optim.Adagrad(para_filter(self.i_encoder),
                                               Config.lr,
                                               weight_decay=Config.l2_penalty)
        self.a_optimizer = torch.optim.Adagrad(para_filter(self.a_encoder),
                                               Config.lr,
                                               weight_decay=Config.l2_penalty)
        self.c_optimizer = torch.optim.Adagrad(self.classifier.parameters(),
                                               Config.lr,
                                               weight_decay=Config.l2_penalty)
        self.d_optimizer = torch.optim.Adam(self.discriminator.parameters(),
                                            Config.lr_d,
                                            weight_decay=Config.l2_penalty)

    def _print_train(self, epoch, time, loss, acc):
        print('-' * 80)
        print(
            '| end of epoch {:3d} | time: {:5.2f}s | loss: {:10.5f} | acc: {:5.2f}% |'
            .format(epoch, time, loss, acc * 100))
        print('-' * 80)

    def _print_eval(self, task, loss, acc):
        print('| ' + task +
              ' loss {:10.5f} | acc {:5.2f}% |'.format(loss, acc * 100))
        print('-' * 80)

    def _save_model(self, model, filename):
        torch.save(model.state_dict(), './weights/' + filename)

    def _load_model(self, model, filename):
        model.load_state_dict(torch.load('./weights/' + filename))

    def _pretrain_i_one(self):
        self.i_encoder.train()
        self.classifier.train()
        total_loss = 0
        correct_n = 0
        for a1, a2i, a2a, sense in self.data.train_loader:
            if self.cuda:
                a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense.cuda()
            a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable(
                a2a), Variable(sense)

            output = self.classifier(self.i_encoder(a1, a2i))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data

            loss = self.criterion_c(output, sense)
            self.i_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.i_optimizer.step()
            self.c_optimizer.step()

            total_loss += loss.data * sense.size(0)
        return total_loss[0] / self.data.train_size, correct_n[
            0] / self.data.train_size

    def _pretrain_i_a_one(self):
        self.i_encoder.train()
        self.a_encoder.train()
        self.classifier.train()
        total_loss = 0
        correct_n = 0
        total_loss_a = 0
        correct_n_a = 0
        for a1, a2i, a2a, sense in self.data.train_loader:
            if self.cuda:
                a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense.cuda()
            a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable(
                a2a), Variable(sense)

            # train i
            output = self.classifier(self.i_encoder(a1, a2i))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data

            loss = self.criterion_c(output, sense)
            self.i_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.i_optimizer.step()
            self.c_optimizer.step()

            total_loss += loss.data * sense.size(0)

            #train a
            output = self.classifier(self.a_encoder(a1, a2a))
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n_a += torch.sum(tmp).data

            loss = self.criterion_c(output, sense)
            self.a_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.a_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.a_optimizer.step()
            self.c_optimizer.step()

            total_loss_a += loss.data * sense.size(0)
        return total_loss[0] / self.data.train_size, correct_n[
            0] / self.data.train_size, total_loss_a[
                0] / self.data.train_size, correct_n_a[0] / self.data.train_size

    def _adtrain_one(self, acc_d_for_train):
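        # Adversarial step in two phases: (1) train the discriminator to tell
        # implicit (i) from explicit (a) argument encodings, using randomized
        # soft targets; (2) train the implicit encoder and classifier to both
        # predict the sense and fool the discriminator (loss_1 + lambda1 * loss_2).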
        total_loss = 0
        total_loss_2 = 0
        correct_n = 0
        correct_n_d = 0
        correct_n_d_for_train = 0
        for a1, a2i, a2a, sense in self.data.train_loader:
            if self.cuda:
                a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense.cuda()
            a1, a2i, a2a, sense = Variable(a1), Variable(a2i), Variable(
                a2a), Variable(sense)

            # phase 1, train discriminator
            flag = 0
            for k in range(Config.kd):
                # if self._test_d() != 1:
                if True:
                    temp_d = 0
                    self.a_encoder.eval()
                    self.i_encoder.eval()
                    self.discriminator.train()
                    self.d_optimizer.zero_grad()
                    output_i = self.discriminator(self.i_encoder(a1, a2i))
                    temp_d += torch.sum((output_i < 0.5).long()).data
                    # zero_tensor = torch.zeros(output_i.size())
                    zero_tensor = torch.Tensor(output_i.size()).random_(
                        0, 100) * 0.003
                    if self.cuda:
                        zero_tensor = zero_tensor.cuda()
                    zero_tensor = Variable(zero_tensor)
                    d_loss_i = self.criterion_d(output_i, zero_tensor)
                    d_loss_i.backward()
                    output_a = self.discriminator(self.a_encoder(a1, a2a))
                    temp_d += torch.sum((output_a >= 0.5).long()).data
                    # one_tensor = torch.ones(output_a.size())
                    # one_tensor = torch.Tensor(output_a.size()).fill_(Config.alpha)
                    one_tensor = torch.Tensor(output_a.size()).random_(
                        0, 100) * 0.005 + 0.7
                    if self.cuda:
                        one_tensor = one_tensor.cuda()
                    one_tensor = Variable(one_tensor)
                    d_loss_a = self.criterion_d(output_a, one_tensor)
                    d_loss_a.backward()
                    correct_n_d_for_train += temp_d
                    temp_d = max(temp_d[0] / sense.size(0) / 2,
                                 acc_d_for_train)
                    if temp_d < Config.thresh_high:
                        torch.nn.utils.clip_grad_norm(
                            self.discriminator.parameters(), Config.grad_clip)
                        self.d_optimizer.step()

            # phase 2, train i/c
            self.i_encoder.train()
            self.classifier.train()
            self.discriminator.eval()
            self.i_optimizer.zero_grad()
            self.c_optimizer.zero_grad()
            sent_repr = self.i_encoder(a1, a2i)

            output = self.classifier(sent_repr)
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense.size()
            tmp = (output_sense == sense).long()
            correct_n += torch.sum(tmp).data
            loss_1 = self.criterion_c(output, sense)

            output_d = self.discriminator(sent_repr)
            correct_n_d += torch.sum((output_d < 0.5).long()).data
            one_tensor = torch.ones(output_d.size())
            # one_tensor = torch.Tensor(output_d.size()).fill_(Config.alpha)
            # one_tensor = torch.Tensor(output_d.size()).random_(0,100) * 0.005 + 0.7
            if self.cuda:
                one_tensor = one_tensor.cuda()
            one_tensor = Variable(one_tensor)
            loss_2 = self.criterion_d(output_d, one_tensor)

            loss = loss_1 + loss_2 * Config.lambda1
            loss.backward()
            torch.nn.utils.clip_grad_norm(self.i_encoder.parameters(),
                                          Config.grad_clip)
            torch.nn.utils.clip_grad_norm(self.classifier.parameters(),
                                          Config.grad_clip)
            self.i_optimizer.step()
            self.c_optimizer.step()

            total_loss += loss.data * sense.size(0)
            total_loss_2 += loss_2.data * sense.size(0)

            test_loss, test_acc = self._eval('test', 'i')
            self.logwriter.add_scalar('acc/test_acc_t_mb', test_acc * 100,
                                      self.i_mb)
            self.i_mb += 1

        return total_loss[0] / self.data.train_size, correct_n[
            0] / self.data.train_size, correct_n_d[
                0] / self.data.train_size, total_loss_2[
                    0] / self.data.train_size, correct_n_d_for_train[
                        0] / self.data.train_size / 2

    def _pretrain_i(self):
        best_test_acc = 0
        for epoch in range(Config.pre_i_epochs):
            start_time = time.time()
            loss, acc = self._pretrain_i_one()
            self._print_train(epoch, time.time() - start_time, loss, acc)
            self.logwriter.add_scalar('loss/train_loss_i', loss, epoch)
            self.logwriter.add_scalar('acc/train_acc_i', acc * 100, epoch)

            dev_loss, dev_acc = self._eval('dev', 'i')
            self._print_eval('dev', dev_loss, dev_acc)
            self.logwriter.add_scalar('loss/dev_loss_i', dev_loss, epoch)
            self.logwriter.add_scalar('acc/dev_acc_i', dev_acc * 100, epoch)

            test_loss, test_acc = self._eval('test', 'i')
            self._print_eval('test', test_loss, test_acc)
            self.logwriter.add_scalar('loss/test_loss_i', test_loss, epoch)
            self.logwriter.add_scalar('acc/test_acc_i', test_acc * 100, epoch)
            if test_acc >= best_test_acc:
                best_test_acc = test_acc
                self._save_model(self.i_encoder, 'i.pkl')
                self._save_model(self.classifier, 'c.pkl')
                print('i_model saved at epoch {}'.format(epoch))

    def _adjust_learning_rate(self, optimizer, lr):
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

    def _train_together(self):
        best_test_acc = 0
        loss = acc = loss_a = acc_a = 0
        lr_t = Config.lr_t
        acc_d_for_train = 0
        for epoch in range(Config.together_epochs):
            start_time = time.time()
            if epoch < Config.first_stage_epochs:
                loss, acc, loss_a, acc_a = self._pretrain_i_a_one()
            else:
                if epoch == Config.first_stage_epochs:
                    self._adjust_learning_rate(self.i_optimizer, lr_t)
                    self._adjust_learning_rate(self.c_optimizer, lr_t / 2)
                # elif (epoch - Config.first_stage_epochs) % 20 == 0:
                #     lr_t *= 0.8
                #     self._adjust_learning_rate(self.i_optimizer, lr_t)
                #     self._adjust_learning_rate(self.c_optimizer, lr_t)
                loss, acc, acc_d, loss_2, acc_d_for_train = self._adtrain_one(
                    acc_d_for_train)
            self._print_train(epoch, time.time() - start_time, loss, acc)
            self.logwriter.add_scalar('loss/train_loss_t', loss, epoch)
            self.logwriter.add_scalar('acc/train_acc_t', acc * 100, epoch)
            self.logwriter.add_scalar('loss/train_loss_t_a', loss_a, epoch)
            self.logwriter.add_scalar('acc/train_acc_t_a', acc_a * 100, epoch)
            if epoch >= Config.first_stage_epochs:
                self.logwriter.add_scalar('acc/train_acc_d', acc_d * 100,
                                          epoch)
                self.logwriter.add_scalar('loss/train_loss_2', loss_2, epoch)
                self.logwriter.add_scalar('acc/acc_d_for_train',
                                          acc_d_for_train * 100, epoch)

            dev_loss, dev_acc = self._eval('dev', 'i')
            dev_loss_a, dev_acc_a = self._eval('dev', 'a')
            self._print_eval('dev', dev_loss, dev_acc)
            self.logwriter.add_scalar('loss/dev_loss_t', dev_loss, epoch)
            self.logwriter.add_scalar('acc/dev_acc_t', dev_acc * 100, epoch)
            self.logwriter.add_scalar('loss/dev_loss_t_a', dev_loss_a, epoch)
            self.logwriter.add_scalar('acc/dev_acc_t_a', dev_acc_a * 100,
                                      epoch)
            if epoch >= Config.first_stage_epochs:
                dev_acc_d = self._eval_d('dev')
                self.logwriter.add_scalar('acc/dev_acc_d', dev_acc_d * 100,
                                          epoch)

            test_loss, test_acc = self._eval('test', 'i')
            test_loss_a, test_acc_a = self._eval('test', 'a')
            self._print_eval('test', test_loss, test_acc)
            self.logwriter.add_scalar('loss/test_loss_t', test_loss, epoch)
            self.logwriter.add_scalar('acc/test_acc_t', test_acc * 100, epoch)
            self.logwriter.add_scalar('loss/test_loss_t_a', test_loss_a, epoch)
            self.logwriter.add_scalar('acc/test_acc_t_a', test_acc_a * 100,
                                      epoch)
            if epoch >= Config.first_stage_epochs:
                test_acc_d = self._eval_d('test')
                self.logwriter.add_scalar('acc/test_acc_d', test_acc_d * 100,
                                          epoch)
            if test_acc >= best_test_acc:
                best_test_acc = test_acc
                self._save_model(self.i_encoder, 't_i.pkl')
                self._save_model(self.classifier, 't_c.pkl')
                print('t_i t_c saved at epoch {}'.format(epoch))

    def train(self, i_or_t):
        print('start training')
        self.logwriter = SummaryWriter(Config.logdir)
        if i_or_t == 'i':
            self._pretrain_i()
        elif i_or_t == 't':
            self._train_together()
        else:
            raise Exception('wrong i_or_t')
        print('training done')

    def _eval(self, task, i_or_a):
        self.i_encoder.eval()
        self.a_encoder.eval()
        self.classifier.eval()
        total_loss = 0
        correct_n = 0
        if task == 'dev':
            data = self.data.dev_loader
            n = self.data.dev_size
        elif task == 'test':
            data = self.data.test_loader
            n = self.data.test_size
        else:
            raise Exception('wrong eval task')
        for a1, a2i, a2a, sense1, sense2 in data:
            if self.cuda:
                a1, a2i, a2a, sense1, sense2 = a1.cuda(), a2i.cuda(), a2a.cuda(
                ), sense1.cuda(), sense2.cuda()
            a1 = Variable(a1, volatile=True)
            a2i = Variable(a2i, volatile=True)
            a2a = Variable(a2a, volatile=True)
            sense1 = Variable(sense1, volatile=True)
            sense2 = Variable(sense2, volatile=True)

            if i_or_a == 'i':
                output = self.classifier(self.i_encoder(a1, a2i))
            elif i_or_a == 'a':
                output = self.classifier(self.a_encoder(a1, a2a))
            else:
                raise Exception('wrong i_or_a')
            _, output_sense = torch.max(output, 1)
            assert output_sense.size() == sense1.size()
            gold_sense = sense1
            mask = (output_sense == sense2)
            gold_sense[mask] = sense2[mask]
            tmp = (output_sense == gold_sense).long()
            correct_n += torch.sum(tmp).data

            loss = self.criterion_c(output, gold_sense)
            total_loss += loss.data * gold_sense.size(0)
        return total_loss[0] / n, correct_n[0] / n

    def _eval_d(self, task):
        self.i_encoder.eval()
        self.a_encoder.eval()
        self.classifier.eval()
        correct_n = 0
        if task == 'train':
            n = self.data.train_size
            for a1, a2i, a2a, sense in self.data.train_loader:
                if self.cuda:
                    a1, a2i, a2a, sense = a1.cuda(), a2i.cuda(), a2a.cuda(
                    ), sense.cuda()
                a1 = Variable(a1, volatile=True)
                a2i = Variable(a2i, volatile=True)
                a2a = Variable(a2a, volatile=True)
                sense = Variable(sense, volatile=True)

                output_i = self.discriminator(self.i_encoder(a1, a2i))
                correct_n += torch.sum((output_i < 0.5).long()).data
                # output_a = self.discriminator(self.a_encoder(a1, a2a))
                # correct_n += torch.sum((output_a >= 0.5).long()).data
        else:
            if task == 'dev':
                data = self.data.dev_loader
                n = self.data.dev_size
            elif task == 'test':
                data = self.data.test_loader
                n = self.data.test_size
            for a1, a2i, a2a, sense1, sense2 in data:
                if self.cuda:
                    a1, a2i, a2a, sense1, sense2 = a1.cuda(), a2i.cuda(
                    ), a2a.cuda(), sense1.cuda(), sense2.cuda()
                a1 = Variable(a1, volatile=True)
                a2i = Variable(a2i, volatile=True)
                a2a = Variable(a2a, volatile=True)
                sense1 = Variable(sense1, volatile=True)
                sense2 = Variable(sense2, volatile=True)

                output_i = self.discriminator(self.i_encoder(a1, a2i))
                correct_n += torch.sum((output_i < 0.5).long()).data
                # output_a = self.discriminator(self.a_encoder(a1, a2a))
                # correct_n += torch.sum((output_a >= 0.5).long()).data
        return correct_n[0] / n

    def _test_d(self):
        acc = self._eval_d('dev')
        phase = -100
        if acc >= Config.thresh_high:
            phase = 1
        elif acc > Config.thresh_low:
            phase = 0
        else:
            phase = -1
        return phase

    def eval(self, stage):
        if stage == 'i':
            self._load_model(self.i_encoder, 'i.pkl')
            self._load_model(self.classifier, 'c.pkl')
            test_loss, test_acc = self._eval('test', 'i')
            self._print_eval('test', test_loss, test_acc)
        elif stage == 't':
            self._load_model(self.i_encoder, 't_i.pkl')
            self._load_model(self.classifier, 't_c.pkl')
            test_loss, test_acc = self._eval('test', 'i')
            self._print_eval('test', test_loss, test_acc)
        else:
            raise Exception('wrong eval stage')
Exemple #22
0
class Solver:

    def __init__(self, loader):
        
        self.loader = loader

        self.c_dim = 4
        
        self.lambda_cls = 10.0
        self.lambda_rec = 10.0
        self.lambda_gp = 10.0

        self.g_lr = 0.0001
        self.d_lr = 0.0001
        self.n_critic = 6
        self.beta1 = 0.5
        self.beta2 = 0.999

        self.smooth_beta = 0.999
        
        self.model_save_step = 1000
        self.lr_update_step = 1000

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.image_size = 256

        self.num_iters = 200000
        self.num_iters_decay = 100000

        self.log_step = 10
        self.sample_step = 10

        # Directories.
        self.log_dir = "log"
        self.sample_dir = "sample"
        self.model_save_dir = "model"
        self.result_dir = "result"
        
        # colors
        self.colors = until.colors
        self.void_classes = until.void_classes
        self.valid_classes = until.valid_classes
        self.class_names = until.class_names
        self.ignore_index = until.ignore_index

        self.n_classes = until.n_classes

        self.label_colours = dict(zip(range(19), self.colors))

        self.class_map = dict(zip(self.valid_classes, range(19)))
        self.class_names = dict(zip(self.class_names, range(19)))
        print(self.class_names)

        self.build_model()

    def build_model(self):
        self.G = Generator(conv_dim=64, c_dim=self.c_dim)
        self.G_test = Generator(conv_dim=64, c_dim=self.c_dim)
        self.D = Discriminator(self.image_size, 64, self.c_dim)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        #self.g_optimizer = torch.optim.RMSprop(self.G.parameters(), lr=self.g_lr, alpha=0.99, eps=1e-8)
        #self.d_optimizer = torch.optim.RMSprop(self.D.parameters(), lr=self.d_lr, alpha=0.99, eps=1e-8)
        
        self.G.to(self.device)
        self.G_test.to(self.device)
        self.D.to(self.device)

        self.update_average(self.G_test, self.G, 0.)

    def eval_model(self):
        self.G.eval()
        self.G_test.eval()
        self.D.eval()

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator."""
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(resume_iters))
        G_test_path = os.path.join(self.model_save_dir, '{}-G_test.ckpt'.format(resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.G_test.load_state_dict(torch.load(G_test_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))


    def reset_grad(self):
        """Reset the gradient buffers."""
        self.g_optimizer.zero_grad()
        self.d_optimizer.zero_grad()

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = torch.flip(x, [1])
        #out = (x + 1) / 2
        return out.clamp_(0, 1)

    def update_lr(self, g_lr, d_lr):
        """Decay learning rates of the generator and discriminator."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: (L2_norm(dy/dx) - 1)**2."""
        weight = torch.ones(y.size()).to(self.device)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx**2, dim=1))
        return torch.mean((dydx_l2norm-1)**2)

    def classification_loss(self, logit, target):
        """Compute binary or softmax cross entropy loss."""
        return F.binary_cross_entropy_with_logits(logit, target, reduction='sum') / logit.size(0)

    def update_average(self, model_tgt, model_src, beta):
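        # Exponential moving average of generator weights:
        # tgt <- beta * tgt + (1 - beta) * src; beta=0 copies src outright,
        # which is how G_test is initialized in build_model.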
        toogle_grad(model_src, False)
        toogle_grad(model_tgt, False)

        param_dict_src = dict(model_src.named_parameters())

        for p_name, p_tgt in model_tgt.named_parameters():
            p_src = param_dict_src[p_name]
            assert(p_src is not p_tgt)
            
            p_tgt.copy_(beta*p_tgt + (1. - beta)*p_src)

    def get_zdist(self, dist_name, dim, device=None):
        # Get distribution
        if dist_name == 'uniform':
            low = -torch.ones(dim, device=device)
            high = torch.ones(dim, device=device)
            zdist = distributions.Uniform(low, high)
        elif dist_name == 'gauss':
            mu = torch.zeros(dim, device=device)
            scale = torch.ones(dim, device=device)
            zdist = distributions.Normal(mu, scale)
        else:
            raise NotImplementedError

        # Add dim attribute
        zdist.dim = dim

        return zdist

    def getBatch(self):
        try:
            x_real, label_org = next(self.data_iter)
        except (AttributeError, StopIteration):
            # iterator not yet created or exhausted: (re)start an epoch
            self.data_iter = iter(self.loader)
            x_real, label_org = next(self.data_iter)
        return x_real, label_org

    def onehot(self, label):
        
        label = label.numpy()
        
        label_onehot = np.zeros((label.shape[0],self.n_classes,label.shape[1],label.shape[2])).astype(np.uint8)
        #print(label_onehot)
        for i in range(self.n_classes):
            label_onehot[:,i,:,:] = (label == i)
            
        #print(np.max(label_onehot))
        label_onehot = torch.from_numpy(label_onehot)
        #print(label_onehot.shape)
        return label_onehot.to(self.device)
    
    def to_label(self, label_onehot):
        
        label = np.zeros((label_onehot.shape[0],1,label_onehot.shape[2],label_onehot.shape[3])).astype(np.uint8)
        label[:,0,:,:] = np.argmax(label_onehot, axis=1)
        label = torch.from_numpy(label)

        return label

    def vis(self, real, label_onehot):
        label = self.to_label(label_onehot)
        label = label.numpy()

        label_colors = np.zeros((label.shape[0],3,label.shape[2],label.shape[3])).astype(np.uint8)

        r = label.copy()
        g = label.copy()
        b = label.copy()

        for l in range(0, self.n_classes):
            r[label == l] = self.label_colours[l][0]
            g[label == l] = self.label_colours[l][1]
            b[label == l] = self.label_colours[l][2]

        r = np.reshape(r, ((label.shape[0], label.shape[2], label.shape[3])))
        g = np.reshape(g, ((label.shape[0], label.shape[2], label.shape[3])))
        b = np.reshape(b, ((label.shape[0], label.shape[2], label.shape[3])))

        rgb = np.zeros((label.shape[0], 3, label.shape[2], label.shape[3]))
        rgb[:, 0, :, :] = r / 255.0
        rgb[:, 1, :, :] = g / 255.0
        rgb[:, 2, :, :] = b / 255.0

        rgb = torch.from_numpy(rgb)
        save_image(rgb, "label.jpg", nrow=1, padding=0)
        save_image(self.denorm(real.data.cpu()), "real.jpg", nrow=1, padding=0)
        #print(label)


    # What fraction of the pixels each label accounts for
    def label_contain_percent(self, label, index=None):
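        # Per image, the fraction of pixels carrying each of the n_classes
        # labels (returned as [batch, n_classes]); pixels with label >=
        # n_classes (the ignore index) are excluded from the denominator.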
        #num_labels=255
        label_per = np.zeros((label.shape[0],self.n_classes,1,1)).astype(np.float32)
        Ns = torch.sum(label<self.n_classes, (1,2), dtype=torch.float32)
        
        if index is None:
            for i in range(self.n_classes):
                label_per[:,i,0,0] = torch.sum(label==i, (1,2), dtype=torch.float32) / Ns 
        else:
            for i in range(len(index)):
                #print(torch.sum(label==index[i], (1,2), dtype=torch.float32))
                label_per[i,index[i],0,0] = torch.sum(label==index[i], (1,2), dtype=torch.float32)[i] / Ns[i]
        label_per = torch.from_numpy(label_per)
        
        #print(torch.sum(label_per, (1)))
        return label_per.view(label_per.size()[0], -1).to(self.device)

    def train(self, start_iter=0):
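        # WGAN-GP loop: D is updated every iteration on real/fake adversarial
        # terms, class-ratio losses, and a gradient penalty; G is updated every
        # n_critic iterations to predict a mask that hides the target category
        # behind noise, with G_test tracking an EMA of G's weights.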

        g_lr = self.g_lr
        d_lr = self.d_lr

        zdist = None

        BCELoss = torch.nn.BCELoss()

        if start_iter > 0:
            self.restore_model(start_iter)

        # Start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iter, self.num_iters):
            
            x_real, label = self.getBatch()

            label = label.clone()
            label_onehot = self.onehot(label)

            #print(c_org)

            # input images
            x_real = x_real.to(self.device)
            
            if zdist is None:
                zdist = self.get_zdist("uniform", (3,x_real.size(2),x_real.size(3)), device=self.device)

            # make noise
            noise = zdist.sample((x_real.size(0),))

            

            if i % 1 == 0:  # always true: the discriminator is updated every iteration
                # train discriminator
                toogle_grad(self.G, False)
                toogle_grad(self.D, True)

                #print(x_real.shape, label_org.shape)
                #print(x_real.shape)
                self.vis(x_real, label_onehot)

                # category to hide in the generated image
                hidden_categorys = [np.random.randint(self.n_classes) for _ in range(x_real.size()[0])]
                hidden_categorys = [self.class_names['car'] for _ in range(x_real.size()[0])]  # debug override: always hide 'car'
                # convert to a one-hot representation
                hidden_categorys_onehot = np.eye(self.n_classes, dtype=np.float32)[hidden_categorys]
                hidden_categorys_onehot = torch.from_numpy(hidden_categorys_onehot).to(self.device)
                #print(hidden_categorys_onehot)

                # what fraction of each label appears in the ground-truth maps
                label_per_real = self.label_contain_percent(label)
                
                #print(label_per_real) # shape [batch, 19, 1, 1]
                out_src, out_cls_real = self.D(x_real)
                #label_real = torch.full((x_real.size(0),1), 1.0, device=self.device)
                
                d_loss_real = -torch.mean(out_src)
                
                # class-ratio loss
                d_loss_cls_real = naive_cross_entropy_loss(out_cls_real, label_per_real) 

                #print(d_loss_cls_real) # shape 1
                
                x_mask = self.G(x_real, hidden_categorys_onehot)
                x_fake = x_mask * x_real + (1.0-x_mask) * noise
                out_src_fake, out_cls_fake = self.D(x_fake.detach())

                #label_fake = torch.full((x_real.size(0),1), 0.0, device=self.device)

                # class-ratio loss
                d_loss_cls_fake = naive_cross_entropy_loss(out_cls_fake, label_per_real) 
                d_loss_fake = torch.mean(out_src_fake)

                # gp_loss
                alpha = torch.rand(x_real.size(0), 1, 1, 1).to(self.device)
                x_hat = (alpha * x_real.data + (1.0 - alpha) * x_fake.data).requires_grad_(True)
                out_src, _ = self.D(x_hat)
                d_loss_gp = self.gradient_penalty(out_src, x_hat)

                d_loss =d_loss_real + d_loss_fake + self.lambda_cls * (d_loss_cls_real+d_loss_cls_fake) + self.lambda_gp * d_loss_gp
                
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # Logging.
                loss = {}
                loss['D/loss_real'] = d_loss_real.item()
                loss['D/loss_fake'] = d_loss_fake.item()
                loss['D/loss_cls_real'] = d_loss_cls_real.item()
                loss['D/loss_cls_fake'] = d_loss_cls_fake.item()
                loss['D/loss_gp'] = d_loss_gp.item()

            # train generator
            if (i+1) % self.n_critic == 0:
                toogle_grad(self.G, True)
                toogle_grad(self.D, False)
                x_mask = self.G(x_real, hidden_categorys_onehot)
                x_fake = x_mask * x_real + (1.0-x_mask) * noise
                out_src, out_cls = self.D(x_fake)

                label_real = torch.full((x_real.size(0),1), 1.0, device=self.device)
                
                g_loss_fake = -torch.mean(out_src)
                g_loss_cls = self.classification_loss(-out_cls+1.0, hidden_categorys_onehot) #naive_cross_entropy_loss(-out_cls+1.0, label_per_real)

                # backward
                g_loss = g_loss_fake + self.lambda_cls * g_loss_cls
                
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # smoothing
                self.update_average(self.G_test, self.G, self.smooth_beta)


                # Logging.
                loss['G/loss_fake'] = g_loss_fake.item()
                loss['G/loss_cls'] = g_loss_cls.item()

            # Print out training information.
            if (i+1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(et, i+1, self.num_iters)
                for tag, value in loss.items():
                    log += ", {}: {:.4f}".format(tag, value)
                print(log)


            # Translate fixed images for debugging.
            if (i+1) % self.sample_step == 0:
                x_fake_list = [x_real]
                x_fake_list.append(x_fake)
                #x_fake_list.append(x_reconst)
                x_concat = torch.cat(x_fake_list, dim=3)
                sample_path = os.path.join(self.sample_dir, '{}-images.jpg'.format(i+1))
                save_image(self.denorm(x_concat.data.cpu()), sample_path, nrow=1, padding=0)
                

            # Save model checkpoints.
            if (i+1) % self.model_save_step == 0:
                G_path = os.path.join(self.model_save_dir, '{}-G.ckpt'.format(i+1))
                G_test_path = os.path.join(self.model_save_dir, '{}-G_test.ckpt'.format(i+1))
                D_path = os.path.join(self.model_save_dir, '{}-D.ckpt'.format(i+1))
                torch.save(self.G.state_dict(), G_path)
                torch.save(self.G_test.state_dict(), G_test_path)
                torch.save(self.D.state_dict(), D_path)
                print('Saved model checkpoints into {}...'.format(self.model_save_dir))

            # Decay lr
            if (i+1) % self.lr_update_step == 0 and (i+1) > (self.num_iters - self.num_iters_decay):
                g_lr -= (self.g_lr / float(self.num_iters_decay))
                d_lr -= (self.d_lr / float(self.num_iters_decay))
                self.update_lr(g_lr, d_lr)
                print ('Decayed learning rates, g_lr: {}, d_lr: {}.'.format(g_lr, d_lr))

    def test(self, test_iters=None):
        """Translate images using StarGAN trained on a single dataset."""
        # Load the trained generator.
        if test_iters is not None:
            self.restore_model(test_iters)

        #self.eval_model()
            
        # Set data loader.
        data_loader = self.loader
            
        with torch.no_grad():
            for i, (x_real, c_org) in enumerate(data_loader):

                # Prepare input images and target domain labels.
                x_real = x_real.to(self.device)
                c_trg_list = []
                for j in range(self.c_dim):
                    c_trg = c_org.clone()
                    c_trg[:,:] = 0.0
                    c_trg[:,j] = 1.0
                    c_trg_list.append(c_trg.to(self.device))
                
                # Translate images.
                x_fake_list = []
                
                for c_trg in c_trg_list:
                    x_fake_list.append(self.G_test(x_real, c_trg))
                print(x_fake_list[0])

                # Save the translated images.
                try:
                    x_concat = torch.cat(x_fake_list, dim=3)
                    result_path = os.path.join(self.result_dir, '{}-images.jpg'.format(i+1))
                    save_image(self.denorm(x_concat.data.cpu()), result_path, nrow=1, padding=0)
                    print('Saved real and fake images into {}...'.format(result_path))
                except:
                    import traceback
                    traceback.print_exc()
                    print('Error {}...'.format(result_path))
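
# Hedged sketches (assumptions, not part of the original source) of the two
# helpers the Solver relies on: soft-target cross entropy over the predicted
# class ratios, and a utility that flips requires_grad on a whole module.
def naive_cross_entropy_loss(logits, soft_targets):
    # batch mean of -sum_c target(c) * log_softmax(logits)(c)
    log_probs = torch.nn.functional.log_softmax(logits, dim=1)
    return -(soft_targets * log_probs).sum(dim=1).mean()

def toogle_grad(model, requires_grad):
    # name kept as spelled in the source
    for p in model.parameters():
        p.requires_grad_(requires_grad)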
Exemple #23
0
            loss_D_h2l_log.append(loss_D_h2l.item())
            loss_D_l2h_log.append(loss_D_l2h.item())
            loss_G_h2l_log.append(loss_G_h2l.item())
            loss_G_l2h_log.append(loss_G_l2h.item())
            loss_cycle_log.append(loss_cycle.item())

            loss_all = (loss_D_h2l_log, loss_D_l2h_log, loss_G_h2l_log,
                        loss_G_l2h_log, loss_cycle_log)

            with open("loss_log.txt", "w") as f:
                json.dump(loss_all, f)

        print("\n Testing and saving...")
        G_h2l.eval()
        D_h2l.eval()
        G_l2h.eval()
        D_l2h.eval()

        if ep % 10 == 0:
            for i, sample in enumerate(test_loader):
                if i >= num_test:
                    break
                low_temp = sample["img16"].numpy()
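                # The [:, ::-1] below flips the channel axis (RGB <-> BGR)
                # before feeding the low-res input to G_l2h.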
                low = torch.from_numpy(
                    np.ascontiguousarray(low_temp[:, ::-1, :, :])).cuda()
                with torch.no_grad():
                    high_gen = G_l2h(low)
                np_low = low.cpu().numpy().transpose(0, 2, 3, 1).squeeze(0)
                np_gen = high_gen.detach().cpu().numpy().transpose(
                    0, 2, 3, 1).squeeze(0)
Exemple #24
0
            real_loss = bce(Dreal, y_real)

            Dfake = D(fake)
            y_fake = torch.zeros_like(Dfake).to(device)
            fake_loss = bce(Dfake, y_fake)
            last_gen_loss = torch.mean(Dfake)

            total_loss = real_loss + fake_loss
            total_loss.backward()
            D_optim.step()
            D_losses.append(total_loss.item())
    print("Epoch %s, G_loss: %f, D_loss: %f, time: %.3f, lr: %.3f" %
          (epoch, Gloss, total_loss, time.time() - train_t, lr))
    if epoch % 10 == 0:
        with torch.no_grad():
            D.eval()
            G.eval()
            plt.figure()
            z_ = sample_z(batch_size, z_dim).to(device).view(-1, z_dim, 1, 1)
            fake = G(z_).squeeze()
            fake = fake.cpu().numpy()
            fake = fake[0, :, :]
            fake[fake < 0] = 0
            plt.imshow(fake)
            plt.colorbar()
            plt.savefig('./plots/' + str(epoch) + ".png")
            plt.close()
    if (epoch + 1) % 50 == 0:
        torch.save(G, save_path + 'G_epoch' + str(epoch))
        torch.save(D, save_path + 'D_epoch' + str(epoch))
        print('Saved Model')
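
sample_z is used above but not defined in this extract. A minimal sketch, assuming a standard normal latent prior (the original helper may differ):

import torch

def sample_z(batch_size, z_dim):
    # One latent vector per sample from N(0, I); the caller reshapes the
    # result to (N, z_dim, 1, 1) before feeding the generator.
    return torch.randn(batch_size, z_dim)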
Exemple #25
0
    def train(self, src_data, tgt_data):
        params = self.params
        print(params)
        penalty = 10.0  # penalty on cosine similarity
        print('Subword penalty {}'.format(penalty))
        # Load data
        if not os.path.exists(params.data_dir):
            raise FileNotFoundError("Data path doesn't exist: %s" % params.data_dir)

        src_lang = params.src_lang
        tgt_lang = params.tgt_lang
        self.suffix_str = src_lang + '_' + tgt_lang

        evaluator = Evaluator(params, src_data=src_data, tgt_data=tgt_data)
        monitor = Monitor(params, src_data=src_data, tgt_data=tgt_data)

        # Initialize subword embedding transformer
        # print('Initializing subword embedding transformer...')
        # src_data['F'].eval()
        # src_optimizer = optim.SGD(src_data['F'].parameters())
        # for _ in trange(128):
        #     indices = np.random.permutation(src_data['seqs'].size(0))
        #     indices = torch.LongTensor(indices)
        #     if torch.cuda.is_available():
        #         indices = indices.cuda()
        #     total_loss = 0
        #     for batch in indices.split(params.mini_batch_size):
        #         src_optimizer.zero_grad()
        #         vecs0 = src_data['vecs'][batch]  # original
        #         vecs = src_data['F'](src_data['seqs'][batch], src_data['E'])
        #         loss = F.mse_loss(vecs0, vecs)
        #         loss.backward()
        #         total_loss += float(loss)
        #         src_optimizer.step()
        # print('Done: final loss = {:.2f}'.format(total_loss))

        src_optimizer = optim.SGD(src_data['F'].parameters(),
                                  lr=params.sw_learning_rate,
                                  momentum=0.9)
        print('Src optim: {}'.format(src_optimizer))
        # Loss function
        loss_fn = torch.nn.BCELoss()

        # Create models
        g = Generator(input_size=params.g_input_size,
                      hidden_size=params.g_hidden_size,
                      output_size=params.g_output_size)

        if self.params.model_file:
            print('Load a model from ' + self.params.model_file)
            g.load(self.params.model_file)

        d = Discriminator(input_size=params.d_input_size,
                          hidden_size=params.d_hidden_size,
                          output_size=params.d_output_size,
                          hyperparams=get_hyperparams(params, disc=True))
        seed = params.seed
        self.initialize_exp(seed)

        if not params.disable_cuda and torch.cuda.is_available():
            print('Use GPU')
            # Move the network and the optimizer to the GPU
            g.cuda()
            d.cuda()
            loss_fn = loss_fn.cuda()

        if self.params.model_file is None:
            print('Initializing G based on distribution')
            # if the relative change of loss values is smaller than tol, stop iteration
            topn = 10000
            tol = 1e-5
            prev_loss, loss = None, None
            g_optimizer = optim.SGD(g.parameters(), lr=0.01, momentum=0.9)

            batches = src_data['seqs'][:topn].split(params.mini_batch_size)
            src_emb = torch.cat([
                src_data['F'](batch, src_data['E']).detach()
                for batch in batches
            ])
            tgt_emb = tgt_data['E'].emb.weight[:topn]
            if not params.disable_cuda and torch.cuda.is_available():
                src_emb = src_emb.cuda()
                tgt_emb = tgt_emb.cuda()
            src_emb = F.normalize(src_emb)
            tgt_emb = F.normalize(tgt_emb)
            src_mean = src_emb.mean(dim=0).detach()
            tgt_mean = tgt_emb.mean(dim=0).detach()
            # src_std = src_emb.std(dim=0).detach()
            # tgt_std = tgt_emb.std(dim=0).detach()

            for _ in trange(1000):  # at most 1000 iterations
                prev_loss = loss
                g_optimizer.zero_grad()
                mapped_src_mean = g(src_mean)
                loss = F.mse_loss(mapped_src_mean, tgt_mean)
                loss.backward()
                g_optimizer.step()
                # Orthogonalize
                self.orthogonalize(g.map1.weight.data)
                loss = float(loss)
                if isinstance(prev_loss, float) and abs(prev_loss - loss) / prev_loss <= tol:
                    break
            print('Done: final loss = {}'.format(float(loss)))
        evaluator.precision(g, src_data, tgt_data)
        sim = monitor.cosine_similarity(g, src_data, tgt_data)
        print('Cos sim.: {:.3f} (+/-{:.3f})'.format(sim.mean(), sim.std()))

        d_acc_epochs, g_loss_epochs = [], []

        # Define optimizers
        d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
        g_optimizer = optim.SGD(g.parameters(), lr=params.g_learning_rate)
        for epoch in range(params.num_epochs):
            d_losses, g_losses = [], []
            hit = 0
            total = 0
            start_time = timer()

            for mini_batch in range(
                    0, params.iters_in_epoch // params.mini_batch_size):
                for d_index in range(params.d_steps):
                    d_optimizer.zero_grad()  # Reset the gradients
                    d.train()

                    X, y, _ = self.get_batch_data(src_data, tgt_data, g)
                    pred = d(X)
                    d_loss = loss_fn(pred, y)
                    d_loss.backward()
                    d_optimizer.step()

                    d_losses.append(d_loss.data.cpu().numpy())
                    discriminator_decision = pred.data.cpu().numpy()
                    hit += np.sum(
                        discriminator_decision[:params.mini_batch_size] >= 0.5)
                    hit += np.sum(
                        discriminator_decision[params.mini_batch_size:] < 0.5)

                    sys.stdout.write("[%d/%d] :: Discriminator Loss: %f \r" %
                                     (mini_batch, params.iters_in_epoch //
                                      params.mini_batch_size,
                                      float(np.mean(d_losses))))
                    sys.stdout.flush()

                total += 2 * params.mini_batch_size * params.d_steps

                for g_index in range(params.g_steps):
                    # 2. Train G on D's response (but DO NOT train D on these labels)
                    g_optimizer.zero_grad()
                    src_optimizer.zero_grad()
                    d.eval()

                    X, y, src_vecs = self.get_batch_data(src_data, tgt_data, g)
                    pred = d(X)
                    g_loss = loss_fn(pred, 1 - y)
                    src_loss = F.mse_loss(*src_vecs)
                    if g_loss.is_cuda:
                        src_loss = src_loss.cuda()
                    loss = g_loss + penalty * src_loss
                    loss.backward()
                    g_optimizer.step()  # Only optimizes G's parameters
                    src_optimizer.step()

                    g_losses.append(g_loss.data.cpu().numpy())

                    # Orthogonalize
                    self.orthogonalize(g.map1.weight.data)

                    sys.stdout.write(
                        "[%d/%d] ::                                     Generator Loss: %f \r"
                        % (mini_batch,
                           params.iters_in_epoch // params.mini_batch_size,
                           float(np.mean(g_losses))))
                    sys.stdout.flush()

                d_acc_epochs.append(hit / total)
                g_loss_epochs.append(float(np.mean(g_losses)))
            print(
                "Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f}, Generator Loss: {:.5f}, Time elapsed {:.2f} mins"
                .format(epoch, float(np.mean(d_losses)), hit / total,
                        float(np.mean(g_losses)),
                        (timer() - start_time) / 60))

            filename = path.join(params.model_dir, 'g_e{}.pth'.format(epoch))
            print('Save a generator to ' + filename)
            g.save(filename)
            filename = path.join(params.model_dir, 's_e{}.pth'.format(epoch))
            print('Save a subword transformer to ' + filename)
            src_data['F'].save(filename)
            if (epoch + 1) % params.print_every == 0:
                evaluator.precision(g, src_data, tgt_data)
                sim = monitor.cosine_similarity(g, src_data, tgt_data)
                print('Cos sim.: {:.3f} (+/-{:.3f})'.format(
                    sim.mean(), sim.std()))

        return g
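
self.orthogonalize(g.map1.weight.data) is called after each generator update above but is not defined in this extract. The usual choice for adversarial embedding mapping (e.g. MUSE) is the iterative update W <- (1 + beta) * W - beta * (W W^T) W; a minimal sketch, assuming beta = 0.01 (the original value may differ):

import torch

def orthogonalize(W, beta=0.01):
    # Nudge the mapping matrix toward the orthogonal manifold, in place:
    # W <- (1 + beta) * W - beta * (W W^T) W.
    W.copy_((1 + beta) * W - beta * W.mm(W.t()).mm(W))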
def training(opt):

    # ~~~~~~~~~~~~~~~~~~~ hyper parameters ~~~~~~~~~~~~~~~~~~~ #
    EPOCHS = opt.epochs
    CHANNELS = 1
    H, W = 64, 64
    lr = opt.lr
    work_device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    FEATURE_D = 128
    Z_DIM = 100
    GEN_TRAIN_STEPS = 5
    BATCH_SIZE = opt.batch_size

    if opt.logs:
        log_dir = Path(f'{opt.logs}').resolve()
        if log_dir.exists():
            shutil.rmtree(str(log_dir))
    else:
        # Assumed default so the SummaryWriters below always have a directory.
        log_dir = Path('./logs').resolve()

    if opt.weights:
        Weight_dir = Path(f'{opt.weights}').resolve()
        if not Weight_dir.exists():
            Weight_dir.mkdir()
    # ~~~~~~~~~~~~~~~~~~~ loading the dataset ~~~~~~~~~~~~~~~~~~~ #

    trans = transforms.Compose([
        transforms.Resize((H, W)),
        transforms.ToTensor(),
        transforms.Normalize((0.5, ), (0.5, ))
    ])

    MNIST_data = MNIST('./data', True, transform=trans, download=True)

    loader = DataLoader(MNIST_data, BATCH_SIZE, True, num_workers=1)

    # ~~~~~~~~~~~~~~~~~~~ creating tensorboard variables ~~~~~~~~~~~~~~~~~~~ #

    writer_fake = SummaryWriter(f"{str(log_dir)}/fake")
    writer_real = SummaryWriter(f"{str(log_dir)}/real")
    loss_writer = SummaryWriter(f"{str(log_dir)}/loss")

    # ~~~~~~~~~~~~~~~~~~~ loading the model ~~~~~~~~~~~~~~~~~~~ #

    disc = Discriminator(img_channels=CHANNELS,
                         feature_d=FEATURE_D).to(work_device)
    gen = Faker(Z_DIM, CHANNELS, FEATURE_D).to(work_device)

    if opt.resume and opt.weights:
        if Path(Weight_dir / 'discriminator.pth').exists():

            disc.load_state_dict(
                torch.load(str(Weight_dir / 'discriminator.pth'),
                           map_location=work_device))

        if Path(Weight_dir / 'generator.pth').exists():

            gen.load_state_dict(
                torch.load(str(Weight_dir / 'generator.pth'),
                           map_location=work_device))

    # ~~~~~~~~~~~~~~~~~~~ create optimizer and loss ~~~~~~~~~~~~~~~~~~~ #

    disc_optim = optim.Adam(disc.parameters(), lr, (0.5, 0.999))
    gen_optim = optim.Adam(gen.parameters(), lr, (0.5, 0.999))
    criterion = torch.nn.BCELoss()

    # ~~~~~~~~~~~~~~~~~~~ training loop ~~~~~~~~~~~~~~~~~~~ #
    D_loss_prev = math.inf
    G_loss_prev = math.inf
    D_loss = 0
    G_loss = 0

    for epoch in range(EPOCHS):

        for batch_idx, (real, _) in enumerate(tqdm(loader)):
            disc.train()
            gen.train()
            real = real.to(work_device)
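            # Note: despite the name, this noise is re-sampled every batch; it
            # is only "fixed" across the D and G steps within one iteration.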
            fixed_noise = torch.rand(real.shape[0], Z_DIM, 1,
                                     1).to(work_device)
            # ~~~~~~~~~~~~~~~~~~~ discriminator loop ~~~~~~~~~~~~~~~~~~~ #

            fake = gen(fixed_noise)  # dim of (N,1,28,28)
            # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
            real_predict = disc(real).view(-1)  # make it one dimensional array
            fake_predict = disc(fake).view(-1)  # make it one dimensional array

            labels = torch.cat([
                torch.ones_like(real_predict),
                torch.zeros_like(fake_predict)
            ],
                               dim=0)

            # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #
            D_loss = criterion(torch.cat([real_predict, fake_predict], dim=0),
                               labels)

            # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
            disc.zero_grad()
            D_loss.backward()
            disc_optim.step()

            # ~~~~~~~~~~~~~~~~~~~ generator loop ~~~~~~~~~~~~~~~~~~~ #
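            # Train G for GEN_TRAIN_STEPS updates per discriminator update,
            # maximizing log D(G(z)) via BCE against all-ones labels.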
            for _ in range(GEN_TRAIN_STEPS):
                # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
                fake = gen(fixed_noise)
                # ~~~~~~~~~~~~~~~~~~~ forward ~~~~~~~~~~~~~~~~~~~ #
                # make it one dimensional array
                fake_predict = disc(fake).view(-1)
                # ~~~~~~~~~~~~~~~~~~~ loss ~~~~~~~~~~~~~~~~~~~ #

                G_loss = criterion(fake_predict, torch.ones_like(fake_predict))
                # ~~~~~~~~~~~~~~~~~~~ backward ~~~~~~~~~~~~~~~~~~~ #
                gen.zero_grad()
                G_loss.backward()
                gen_optim.step()

            # ~~~~~~~~~~~~~~~~~~~ loading the tensorboard ~~~~~~~~~~~~~~~~~~~ #

            if batch_idx == 0:
                print(
                    f"Epoch [{epoch}/{EPOCHS}] Batch {batch_idx}/{len(loader)} \
                                Loss D: {D_loss:.4f}, loss G: {G_loss:.4f}")

                with torch.no_grad():
                    disc.eval()
                    gen.eval()
                    fake = gen(fixed_noise).reshape(-1, CHANNELS, H, W)
                    data = real.reshape(-1, CHANNELS, H, W)
                    if BATCH_SIZE > 32:
                        fake = fake[:32]
                        data = data[:32]
                    img_grid_fake = torchvision.utils.make_grid(fake,
                                                                normalize=True)
                    img_grid_real = torchvision.utils.make_grid(data,
                                                                normalize=True)

                    writer_fake.add_image("Mnist Fake Images",
                                          img_grid_fake,
                                          global_step=epoch)
                    writer_real.add_image("Mnist Real Images",
                                          img_grid_real,
                                          global_step=epoch)
                    loss_writer.add_scalar('discriminator',
                                           D_loss,
                                           global_step=epoch)
                    loss_writer.add_scalar('generator',
                                           G_loss,
                                           global_step=epoch)

        # ~~~~~~~~~~~~~~~~~~~ saving the weights ~~~~~~~~~~~~~~~~~~~ #
        if opt.weights:
            if D_loss_prev > float(D_loss):
                D_loss_prev = float(D_loss)
                weight_path = str(Weight_dir / 'discriminator.pth')
                torch.save(disc.state_dict(), weight_path)

            if G_loss_prev > float(G_loss):
                G_loss_prev = float(G_loss)
                weight_path = str(Weight_dir / 'generator.pth')
                torch.save(gen.state_dict(), weight_path)
    def train(self, src_emb, tgt_emb):
        params = self.params
        # Load data
        if not os.path.exists(params.data_dir):
            raise FileNotFoundError("Data path doesn't exist: %s" % params.data_dir)

        en = src_emb
        it = tgt_emb
        self.params = _get_eval_params(params)
        params = self.params

        for _ in range(params.num_random_seeds):

            # Create models
            g = Generator(input_size=params.g_input_size, output_size=params.g_output_size)
            d = Discriminator(input_size=params.d_input_size, hidden_size=params.d_hidden_size, output_size=params.d_output_size)
            print(d)
            lowest_loss = 1e5
            
            g.apply(self.weights_init3)

            seed = random.randint(0, 1000)
            self.initialize_exp(seed)
            
            loss_fn = torch.nn.BCELoss()
            loss_fn2 = torch.nn.CosineSimilarity(dim=1, eps=1e-6)
            #d_optimizer = optim.SGD(d.parameters(), lr=params.d_learning_rate)
            #g_optimizer = optim.SGD(g.parameters(), lr=params.g_learning_rate)
            #d_optimizer = optim.Adam(d.parameters(), lr=params.d_learning_rate)
            d_optimizer = optim.RMSprop(d.parameters(), lr=params.d_learning_rate)
            g_optimizer = optim.Adam(g.parameters(), lr=params.g_learning_rate)

            if torch.cuda.is_available():
                # Move the network and the optimizer to the GPU
                g = g.cuda()
                d = d.cuda()
                loss_fn = loss_fn.cuda()
                loss_fn2 = loss_fn2.cuda()
            # true_dict = get_true_dict(params.data_dir)
            d_acc_epochs = []
            g_loss_epochs = []
            d_loss_epochs = []
            acc_all = []
            d_losses = []
            g_losses = []
            csls_epochs = []
            recon_losses = []
            w_losses = []

            try:
                for epoch in range(params.num_epochs):
                    recon_losses = []
                    w_losses = []
                    start_time = timer()

                    for mini_batch in range(0, params.iters_in_epoch // params.mini_batch_size):
                        hit, total = 0, 0
                        for d_index in range(params.d_steps):
                            d_optimizer.zero_grad()  # Reset the gradients
                            d.train()
                            # input, output = self.get_batch_data_fast(en, it, g, detach=True)
                            src_batch, tgt_batch = self.get_batch_data_fast_new(en, it)
                            fake,_ = g(src_batch)
                            fake = fake.detach()
                            real = tgt_batch
                            # input = torch.cat([fake, real], 0)
                            input = torch.cat([real, fake], 0)
                            output = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_())

                            output[:params.mini_batch_size] = 1 - params.smoothing
                            output[params.mini_batch_size:] = params.smoothing

                            pred = d(input)
                            d_loss = loss_fn(pred, output)
                            d_loss.backward()  # compute/store gradients, but don't change params
                            d_losses.append(d_loss.data.cpu().numpy())
                            discriminator_decision = pred.data.cpu().numpy()
                            hit += np.sum(discriminator_decision[:params.mini_batch_size] >= 0.5)
                            hit += np.sum(discriminator_decision[params.mini_batch_size:] < 0.5)
                            d_optimizer.step()  # Only optimizes D's parameters; changes based on stored gradients from backward()

                            # Clip weights
                            _clip(d, params.clip_value)

                            sys.stdout.write("[%d/%d] :: Discriminator Loss: %f \r" % (
                                mini_batch, params.iters_in_epoch // params.mini_batch_size,
                                float(np.mean(d_losses[-1000:]))))
                            sys.stdout.flush()

                        total += 2 * params.mini_batch_size * params.d_steps

                        for g_index in range(params.g_steps):
                            # 2. Train G on D's response (but DO NOT train D on these labels)
                            g_optimizer.zero_grad()
                            d.eval()
                            src_batch, tgt_batch = self.get_batch_data_fast_new(en, it)
                            fake, recon = g(src_batch)
                            real = tgt_batch
                            output = to_variable(torch.FloatTensor(2 * params.mini_batch_size).zero_())
                            output[:params.mini_batch_size] = 1 - params.smoothing
                            output[params.mini_batch_size:] = params.smoothing

                            pred = d(fake)
                            output2 = to_variable(torch.FloatTensor(params.mini_batch_size).zero_())
                            output2 = output2 + 1 - params.smoothing
                            
                            recon_loss = 1.0 - torch.mean(loss_fn2(src_batch, recon))
                            g_loss = loss_fn(pred, output2) + params.recon_weight * recon_loss
                            g_loss.backward()
                            
                            g_losses.append(g_loss.data.cpu().numpy())
                            recon_losses.append(recon_loss.data.cpu().numpy())

                            g_optimizer.step()  # Only optimizes G's parameters
                            #self.orthogonalize(g.map1.weight.data)

                            sys.stdout.write("[%d/%d] ::                                     Generator Loss: %f \r" % (
                                mini_batch, params.iters_in_epoch // params.mini_batch_size,
                                float(np.mean(g_losses[-1000:]))))
                            sys.stdout.flush()

                        acc_all.append(hit / total)
                        
                        if epoch > params.threshold:
                            if lowest_loss > float(g_loss.data):
                                lowest_loss = float(g_loss.data)
                                W = g.map1.weight.data.cpu().numpy()
                                w_losses.append(np.linalg.norm(np.dot(W.T, W) - np.identity(params.g_input_size)))

                                X_Z = g(src_emb.weight)[0].data
                                Y_Z = tgt_emb.weight.data

                                mstart_time = timer()
                                for method in [params.dico_method]:
                                    results = get_word_translation_accuracy(
                                        'en', self.src_ids, X_Z,
                                        'zh', self.tgt_ids, Y_Z,
                                        method=method,
                                        path = params.data_dir+params.validation_file
                                    )
                                    acc = results[0][1]
                                    #print('{} takes {:.2f}s'.format(method, timer() - mstart_time))
                                    #print('Method:{} score:{:.4f}'.format(method,acc))

                                    torch.save(g.state_dict(),
                                               'tune/best/G_seed{}_epoch_{}_batch_{}_mf_{}_p@1_{:.3f}.t7'.format(
                                                   seed, epoch, mini_batch, params.most_frequent_sampling_size, acc))
 
                        '''
                        if mini_batch % 500==0:
                            #d_acc_epochs.append(hit / total)
                            #d_loss_epochs.append(np.asscalar(np.mean(d_losses)))
                            #g_loss_epochs.append(np.asscalar(np.mean(g_losses)))
                            if epoch > params.threshold:
                                W = g.map1.weight.data.cpu().numpy()
                                w_loss = np.linalg.norm(np.dot(W.T, W) - np.identity(params.g_input_size))
                                #print("D_acc:{:.3f} d_loss:{:.3f} g_loss:{:.3f} w_loss:{:.2f} ".format(hit / total,np.asscalar(np.mean(d_losses)),np.asscalar(np.mean(g_losses)),w_loss))
                                #print("D_acc:{:.3f} d_loss:{:.3f} g_loss:{:.3f} w_loss:{:.2f}".format(hit / total,d_loss.data[0],g_loss.data[0],w_loss))

                                X_Z = g(src_emb.weight)[0].data
                                Y_Z = tgt_emb.weight.data

                                mstart_time = timer()
                                for method in [params.dico_method]:
                                    results = get_word_translation_accuracy(
                                        'en', self.src_ids, X_Z,
                                        'zh', self.tgt_ids, Y_Z,
                                        method=method,
                                        path = params.validation_file
                                    )
                                    acc = results[0][1]
                                    #print('epoch:{} Method:{} score:{:.4f}'.format(mini_batch,method,acc))
                                    torch.save(g.state_dict(),
                                       'tune/thu/G_seed{}_epoch_{}_mf_{}_p@1_{:.3f}.t7'.format(seed,mini_batch,params.most_frequent_sampling_size,acc))
                        '''
                    X_Z = g(src_emb.weight)[0].data
                    Y_Z = tgt_emb.weight.data
                    print("Epoch {} : Discriminator Loss: {:.5f}, Discriminator Accuracy: {:.5f},Generator Loss: {:.5f}, Time elapsed {:.2f}mins".format(epoch,np.asscalar(np.mean(d_losses[-1562:])), hit / total,np.asscalar(np.mean(g_losses[-1562:])),(timer() - start_time) / 60))

                    mstart_time = timer()
                    for method in [params.dico_method]:
                        results = get_word_translation_accuracy(
                            'en', self.src_ids, X_Z,
                            'zh', self.tgt_ids, Y_Z,
                            method=method,
                            path = params.data_dir+params.validation_file
                        )
                        acc = results[0][1]
                        print('epoch:{} Method:{} score:{:.4f}'.format(epoch, method, acc))
                        torch.save(g.state_dict(),
                                   'tune/G_seed{}_epoch_{}_mf_{}_p@1_{:.3f}.t7'.format(
                                       seed, epoch, params.most_frequent_sampling_size, acc))

                
                # Save the plot for discriminator accuracy and generator loss
                fig = plt.figure()
                plt.plot(range(0, len(acc_all)), acc_all, color='b', label='discriminator')
                plt.ylabel('D_accuracy_all')
                plt.xlabel('iterations')
                plt.legend()
                fig.savefig('tune/D_acc_all.png')

                fig = plt.figure()
                plt.plot(range(0, len(d_losses)), d_losses, color='b', label='discriminator')
                plt.ylabel('D_loss_all')
                plt.xlabel('iterations')
                plt.legend()
                fig.savefig('tune/D_loss_all.png')

                fig = plt.figure()
                plt.plot(range(0, len(g_losses)), g_losses, color='b', label='generator')
                plt.ylabel('G_loss_all')
                plt.xlabel('iterations')
                plt.legend()
                fig.savefig('tune/G_loss_all.png')

                fig = plt.figure()
                plt.plot(range(0, len(w_losses)), w_losses, color='b', label='orthogonality')
                plt.ylabel('||W^T*W - I||')
                plt.xlabel('iterations')
                plt.legend()
                fig.savefig('tune/W^TW.png')

                plt.close('all')

                plt.close('all')

            except KeyboardInterrupt:
                print("Interrupted.. saving model !!!")
                torch.save(g.state_dict(), 'tune/g_model_interrupt.t7')
                torch.save(d.state_dict(), 'tune/d_model_interrupt.t7')
                exit()

        return g
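
_clip is applied to the discriminator right after each update above but is not defined in this extract; together with the RMSprop optimizer this reads as WGAN-style weight clipping. A minimal sketch matching the call _clip(d, params.clip_value) (an assumption, not the confirmed original):

def _clip(model, clip_value):
    # Clamp every parameter into [-clip_value, clip_value] in place,
    # keeping the critic's weights bounded (WGAN-style clipping).
    for p in model.parameters():
        p.data.clamp_(-clip_value, clip_value)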
Exemple #28
0
class FNM(object):
    def __init__(self, args):
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        os.environ["CUDA_VISIBLE_DEVICES"] = args.device_id
        self.batch_size = args.batch_size
        self.lr = args.lr
        self.profile_list_path = args.profile_list
        self.front_list_path = args.front_list
        self.profile_path = args.profile_path
        self.front_path = args.front_path
        self.test_path = args.test_path
        self.test_list = args.test_list

        self.crop_size = args.ori_height
        self.image_size = args.height
        self.res_n = args.res_n
        self.is_finetune = args.is_finetune
        self.result_name = args.result_name
        self.summary_dir = args.summary_dir
        self.iteration = args.iteration
        self.weight_decay = args.weight_decay
        self.decay_flag = args.decay_flag
        self.print_freq = args.print_freq
        self.save_freq = args.save_freq
        self.img_size = args.width
        self.model_name = args.model_name

        # For hyper parameters
        self.lambda_l1 = args.lambda_l1
        self.lambda_fea = args.lambda_fea
        self.lambda_reg = args.lambda_reg
        self.lambda_gan = args.lambda_gan
        self.lambda_gp = args.lambda_gp

        self.channel = args.channel
        # CUDA_VISIBLE_DEVICES above leaves a single visible GPU, which is cuda:0.
        self.device = torch.device("cuda:0")
        self.make_dirs()
        self.build_model()
        """Define Loss"""
        self.L1_loss = nn.L1Loss().to(self.device)
        self.L2_loss = nn.MSELoss().to(self.device)

    def make_dirs(self):
        check_folder(self.summary_dir)
        check_folder(os.path.join("results", self.result_name, "model"))
        check_folder(os.path.join("results", self.result_name, "img"))

    def build_model(self):
        self.expert_net = se50_net(
            "./other_models/arcface_se50/model_ir_se50.pth").to(self.device)
        for param in self.expert_net.parameters():
            param.requires_grad = False
        #self.dataset = sample_dataset(self.profile_list_path, self.front_list_path, self.profile_path, self.front_path, self.crop_size, self.image_size)
        self.front_loader = get_loader(self.front_list_path,
                                       self.front_path,
                                       self.crop_size,
                                       self.image_size,
                                       self.batch_size,
                                       mode="train",
                                       num_workers=8)

        self.profile_loader = get_loader(self.profile_list_path,
                                         self.profile_path,
                                         self.crop_size,
                                         self.image_size,
                                         self.batch_size,
                                         mode="train",
                                         num_workers=8)

        self.test_loader = get_loader(self.test_list,
                                      self.test_path,
                                      self.crop_size,
                                      self.image_size,
                                      self.batch_size,
                                      mode="test",
                                      num_workers=8)

        #self.front_loader = iter(self.front_loader)
        #self.profile_loader = iter(self.profile_loader)
        #resnet_blocks
        resnet_block_list = []
        for i in range(self.res_n):
            resnet_block_list.append(ResnetBlock(512, use_bias=False))

        self.body = nn.Sequential(*resnet_block_list).to(self.device)
        #[b, 512, 7, 7]
        self.decoder = Decoder().to(self.device)
        self.dis = Discriminator(self.channel).to(self.device)

        self.G_optim = torch.optim.Adam(itertools.chain(
            self.body.parameters(), self.decoder.parameters()),
                                        lr=self.lr,
                                        betas=(0.5, 0.999),
                                        weight_decay=self.weight_decay)

        self.D_optim = torch.optim.Adam(itertools.chain(self.dis.parameters()),
                                        lr=self.lr,
                                        betas=(0.5, 0.999),
                                        weight_decay=self.weight_decay)

        # Resizes the 224x224 generator output to the 112x112 expert-net input.
        self.downsample112x112 = nn.Upsample(size=(112, 112), mode='bilinear',
                                             align_corners=False)

    def update_lr(self, cur_iter):
        # Linearly decay both learning rates to zero over the second half of
        # training, setting them absolutely so repeated calls are safe.
        if self.decay_flag and cur_iter > (self.iteration // 2):
            lr = self.lr * (1.0 - (cur_iter - self.iteration // 2) /
                            (self.iteration // 2))
            self.G_optim.param_groups[0]['lr'] = lr
            self.D_optim.param_groups[0]['lr'] = lr

    def train(self):
        self.body.train(), self.decoder.train(), self.dis.train()
        start_iter = 1
        if self.is_finetune:
            model_list = glob(
                os.path.join("results", self.result_name, "model", "*.pt"))
            if not len(model_list) == 0:
                model_list.sort()
                start_iter = int(model_list[-1].split('_')[-1].split('.')[0])
                self.load(os.path.join("results", self.result_name, 'model'),
                          start_iter)
                print(" [*] Load SUCCESS")
                self.update_lr(start_iter)
        print("training start...")
        start_time = time.time()
        for step in range(start_iter, self.iteration + 1):
            self.update_lr(step)
            # Pull the next full batch; rebuild the iterator when it is
            # exhausted, yields a short batch, or does not exist yet.
            try:
                front_224, front_112 = next(front_iter)
                if front_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (StopIteration, NameError):
                front_iter = iter(self.front_loader)
                front_224, front_112 = next(front_iter)
            try:
                profile_224, profile_112 = next(profile_iter)
                if profile_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (StopIteration, NameError):
                profile_iter = iter(self.profile_loader)
                profile_224, profile_112 = next(profile_iter)

            profile_224, front_224, profile_112, front_112 = profile_224.to(
                self.device), front_224.to(self.device), profile_112.to(
                    self.device), front_112.to(self.device)

            # Update D
            self.D_optim.zero_grad()

            feature_p = self.expert_net.get_feature(profile_112)
            feature_f = self.expert_net.get_feature(front_112)
            gen_p = self.decoder(self.body(feature_p))
            gen_f = self.decoder(self.body(feature_f))
            feature_gen_p = self.expert_net.get_feature(
                self.downsample112x112(gen_p))
            feature_gen_f = self.expert_net.get_feature(
                self.downsample112x112(gen_f))
            d_f = self.dis(front_224)
            d_gen_p = self.dis(gen_p)
            d_gen_f = self.dis(gen_f)

            D_adv_loss = torch.mean(
                tensor_tuple_sum(d_gen_f) * 0.5 +
                tensor_tuple_sum(d_gen_p) * 0.5 - tensor_tuple_sum(d_f)) / 5

            alpha = torch.rand(gen_p.size(0), 1, 1, 1).to(self.device)
            inter = (alpha * front_224.data +
                     (1 - alpha) * gen_p.data).requires_grad_(True)
            out_inter = self.dis(inter)
            gradient_penalty_loss = (
                gradient_penalty(out_inter[0], inter, self.device) +
                gradient_penalty(out_inter[1], inter, self.device) +
                gradient_penalty(out_inter[2], inter, self.device) +
                gradient_penalty(out_inter[3], inter, self.device) +
                gradient_penalty(out_inter[4], inter, self.device)) / 5
            #print("gradient_penalty_loss:{}".format(gradient_penalty_loss))
            d_loss = self.lambda_gan * D_adv_loss + self.lambda_gp * gradient_penalty_loss
            d_loss.backward(retain_graph=True)
            self.D_optim.step()

            # Update G
            self.G_optim.zero_grad()
            try:
                front_224, front_112 = next(front_iter)
                if front_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (StopIteration, NameError):
                front_iter = iter(self.front_loader)
                front_224, front_112 = next(front_iter)
            try:
                profile_224, profile_112 = next(profile_iter)
                if profile_224.shape[0] != self.batch_size:
                    raise StopIteration
            except (StopIteration, NameError):
                profile_iter = iter(self.profile_loader)
                profile_224, profile_112 = next(profile_iter)

            profile_224, front_224, profile_112, front_112 = profile_224.to(
                self.device), front_224.to(self.device), profile_112.to(
                    self.device), front_112.to(self.device)

            feature_p = self.expert_net.get_feature(profile_112)
            feature_f = self.expert_net.get_feature(front_112)
            gen_p = self.decoder(self.body(feature_p))
            gen_f = self.decoder(self.body(feature_f))
            feature_gen_p = self.expert_net.get_feature(
                self.downsample112x112(gen_p))
            feature_gen_f = self.expert_net.get_feature(
                self.downsample112x112(gen_f))
            d_f = self.dis(front_224)
            d_gen_p = self.dis(gen_p)
            d_gen_f = self.dis(gen_f)

            pixel_loss = torch.mean(self.L1_loss(front_224, gen_f))

            feature_p_norm = l2_norm(feature_p)
            feature_f_norm = l2_norm(feature_f)
            feature_gen_p_norm = l2_norm(feature_gen_p)
            feature_gen_f_norm = l2_norm(feature_gen_f)

            perceptual_loss = torch.mean(
                0.5 *
                (1 - torch.sum(torch.mul(feature_p_norm, feature_gen_p_norm),
                               dim=(1, 2, 3))) + 0.5 *
                (1 - torch.sum(torch.mul(feature_f_norm, feature_gen_f_norm),
                               dim=(1, 2, 3))))

            G_adv_loss = -torch.mean(
                tensor_tuple_sum(d_gen_f) * 0.5 +
                tensor_tuple_sum(d_gen_p) * 0.5) / 5
            g_loss = self.lambda_gan * G_adv_loss + self.lambda_l1 * pixel_loss + self.lambda_fea * perceptual_loss
            g_loss.backward()
            self.G_optim.step()

            print("[%5d/%5d] time: %4.4f d_loss: %.8f, g_loss: %.8f" %
                  (step, self.iteration, time.time() - start_time, d_loss,
                   g_loss))
            print("D_adv_loss : %.8f" % (self.lambda_gan * D_adv_loss))
            print("G_adv_loss : %.8f" % (self.lambda_gan * G_adv_loss))
            print("pixel_loss : %.8f" % (self.lambda_l1 * pixel_loss))
            print("perceptual_loss : %.8f" %
                  (self.lambda_fea * perceptual_loss))
            print("gp_loss : %.8f" % (self.lambda_gp * gradient_penalty_loss))

            with torch.no_grad():
                if step % self.print_freq == 0:
                    train_sample_num = 5
                    test_sample_num = 5
                    A2B = np.zeros((self.img_size * 4, 0, 3))
                    self.body.eval(), self.decoder.eval(), self.dis.eval()
                    for _ in range(train_sample_num):
                        try:
                            front_224, front_112 = next(front_iter)
                            if front_224.shape[0] != self.batch_size:
                                raise StopIteration
                        except (StopIteration, NameError):
                            front_iter = iter(self.front_loader)
                            front_224, front_112 = next(front_iter)
                        try:
                            profile_224, profile_112 = next(profile_iter)
                            if profile_224.shape[0] != self.batch_size:
                                raise StopIteration
                        except (StopIteration, NameError):
                            profile_iter = iter(self.profile_loader)
                            profile_224, profile_112 = next(profile_iter)

                        profile_224, front_224, profile_112, front_112 = profile_224.to(
                            self.device), front_224.to(
                                self.device), profile_112.to(
                                    self.device), front_112.to(self.device)

                        feature_p = self.expert_net.get_feature(profile_112)
                        feature_f = self.expert_net.get_feature(front_112)
                        gen_p = self.decoder(self.body(feature_p))
                        gen_f = self.decoder(self.body(feature_f))

                        A2B = np.concatenate(
                            (A2B,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(
                                     profile_224[0]))),
                                  RGB2BGR(tensor2numpy(denorm(gen_p[0]))),
                                  RGB2BGR(tensor2numpy(denorm(front_224[0]))),
                                  RGB2BGR(tensor2numpy(denorm(gen_f[0])))),
                                 0)), 1)

                    for _ in range(test_sample_num):
                        show_list = []
                        for i in range(2):
                            try:
                                test_profile_224, test_profile_112 = next(test_iter)
                                if test_profile_224.shape[0] != self.batch_size:
                                    raise StopIteration
                            except (StopIteration, NameError):
                                test_iter = iter(self.test_loader)
                                test_profile_224, test_profile_112 = next(test_iter)
                            test_profile_224, test_profile_112 = test_profile_224.to(
                                self.device), test_profile_112.to(self.device)
                            test_feature_p = self.expert_net.get_feature(
                                test_profile_112)
                            test_gen_p = self.decoder(
                                self.body(test_feature_p))
                            show_list.append(test_profile_224[0])
                            show_list.append(test_gen_p[0])

                        A2B = np.concatenate(
                            (A2B,
                             np.concatenate(
                                 (RGB2BGR(tensor2numpy(denorm(show_list[0]))),
                                  RGB2BGR(tensor2numpy(denorm(show_list[1]))),
                                  RGB2BGR(tensor2numpy(denorm(show_list[2]))),
                                  RGB2BGR(tensor2numpy(denorm(show_list[3])))),
                                 0)), 1)

                    cv2.imwrite(
                        os.path.join("results", self.result_name, 'img',
                                     'A2B_%07d.png' % step), A2B * 255.0)
                    self.body.train(), self.decoder.train(), self.dis.train()

                if step % self.save_freq == 0:
                    self.save(
                        os.path.join("results", self.result_name, "model"),
                        step)

                if step % 1000 == 0:
                    params = {}
                    params['body'] = self.body.state_dict()
                    params['decoder'] = self.decoder.state_dict()
                    params['dis'] = self.dis.state_dict()
                    torch.save(
                        params,
                        os.path.join("results", self.result_name,
                                     self.model_name + "_params_latest.pt"))

    def load(self, dir, step):
        params = torch.load(
            os.path.join(dir, self.model_name + '_params_%07d.pt' % step))
        self.body.load_state_dict(params['body'])
        self.decoder.load_state_dict(params['decoder'])
        self.dis.load_state_dict(params['dis'])

    def save(self, dir, step):
        params = {}
        params['body'] = self.body.state_dict()
        params['decoder'] = self.decoder.state_dict()
        params['dis'] = self.dis.state_dict()
        torch.save(
            params,
            os.path.join(dir, self.model_name + '_params_%07d.pt' % step))

    def demo(self):
        try:
            front_224, front_112 = next(front_iter)
            if front_224.shape[0] != self.batch_size:
                raise StopIteration
        except (StopIteration, NameError):
            front_iter = iter(self.front_loader)
            front_224, front_112 = next(front_iter)
        try:
            profile_224, profile_112 = next(profile_iter)
            if profile_224.shape[0] != self.batch_size:
                raise StopIteration
        except (StopIteration, NameError):
            profile_iter = iter(self.profile_loader)
            profile_224, profile_112 = next(profile_iter)

        profile_224, front_224, profile_112, front_112 = profile_224.to(
            self.device), front_224.to(self.device), profile_112.to(
                self.device), front_112.to(self.device)
        D_face, D_eye, D_nose, D_mouth, D_map = self.dis(profile_224)
        '''
        print("D_face.shape:", D_face.shape)
        print("D_eye.shape:", D_eye.shape)
        print("D_nose.shape:", D_nose.shape)
        print("D_mouth.shape:", D_mouth.shape)
        '''
        cv2.imwrite("profile.jpg",
                    cv2.cvtColor(tensor2im(profile_112), cv2.COLOR_BGR2RGB))
        cv2.imwrite("front.jpg",
                    cv2.cvtColor(tensor2im(front_112), cv2.COLOR_BGR2RGB))
        feature = self.expert_net.get_feature(profile_224)
        print(feature.shape)
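
gradient_penalty is called per discriminator head in the D update above but is not defined in this extract; the standard WGAN-GP penalty is the natural reading. A minimal sketch matching the call gradient_penalty(out, x, device) (an assumption, not the confirmed original):

import torch

def gradient_penalty(out, x, device):
    # WGAN-GP: penalize deviation of the critic's gradient norm from 1 at
    # the interpolated samples x, i.e. E[(||grad_x D(x)||_2 - 1)^2].
    grad = torch.autograd.grad(outputs=out, inputs=x,
                               grad_outputs=torch.ones_like(out).to(device),
                               create_graph=True, retain_graph=True)[0]
    grad = grad.view(grad.size(0), -1)
    return ((grad.norm(2, dim=1) - 1) ** 2).mean()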
Exemple #29
0
class Solver(object):
    def __init__(self, config, data_loader):
        self.generator = None
        self.discriminator = None
        self.g_optimizer = None
        self.d_optimizer = None
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.z_dim = config.z_dim
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.image_size = config.image_size
        self.data_loader = data_loader
        self.num_epochs = config.num_epochs
        self.batch_size = config.batch_size
        self.sample_size = config.sample_size
        self.lr = config.lr
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.sample_path = config.sample_path
        self.model_path = config.model_path
        self.build_model()

    def build_model(self):
        """Build generator and discriminator."""
        self.generator = Generator(z_dim=self.z_dim,
                                   image_size=self.image_size,
                                   conv_dim=self.g_conv_dim)
        self.discriminator = Discriminator(image_size=self.image_size,
                                           conv_dim=self.d_conv_dim)
        self.g_optimizer = optim.Adam(self.generator.parameters(), self.lr,
                                      [self.beta1, self.beta2])
        self.d_optimizer = optim.Adam(self.discriminator.parameters(), self.lr,
                                      [self.beta1, self.beta2])

        if torch.cuda.is_available():
            self.generator.cuda()
            self.discriminator.cuda()

    def to_variable(self, x):
        """Convert tensor to variable."""
        if torch.cuda.is_available():
            x = x.cuda()
        return Variable(x)

    def to_data(self, x):
        """Convert variable to tensor."""
        if torch.cuda.is_available():
            x = x.cpu()
        return x.data

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.discriminator.zero_grad()
        self.generator.zero_grad()

    def denorm(self, x):
        """Convert range (-1, 1) to (0, 1)"""
        out = (x + 1) / 2
        return out.clamp(0, 1)

    def train(self):
        """Train generator and discriminator."""
        fixed_noise = self.to_variable(torch.randn(self.batch_size,
                                                   self.z_dim))
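        # This noise really is fixed: reusing it for every sampled grid makes
        # the saved images comparable across training steps.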
        total_step = len(self.data_loader)
        for epoch in range(self.num_epochs):
            for i, images in enumerate(self.data_loader):

                # ===================== Train D =====================#
                images = self.to_variable(images)
                batch_size = images.size(0)
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train D to recognize real images as real.
                outputs = self.discriminator(images)
                real_loss = torch.mean(
                    (outputs - 1)**2
                )  # L2 loss instead of Binary cross entropy loss (this is optional for stable training)

                # Train D to recognize fake images as fake.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                fake_loss = torch.mean(outputs**2)

                # Backprop + optimize
                d_loss = real_loss + fake_loss
                self.reset_grad()
                d_loss.backward()
                self.d_optimizer.step()

                # ===================== Train G =====================#
                noise = self.to_variable(torch.randn(batch_size, self.z_dim))

                # Train G so that D recognizes G(z) as real.
                fake_images = self.generator(noise)
                outputs = self.discriminator(fake_images)
                g_loss = torch.mean((outputs - 1)**2)

                # Backprop + optimize
                self.reset_grad()
                g_loss.backward()
                self.g_optimizer.step()

                # print the log info
                if (i + 1) % self.log_step == 0:
                    print(
                        'Epoch [%d/%d], Step[%d/%d], d_real_loss: %.4f, '
                        'd_fake_loss: %.4f, g_loss: %.4f' %
                        (epoch + 1, self.num_epochs, i + 1, total_step,
                         real_loss.item(), fake_loss.item(), g_loss.item()))

                # save the sampled images
                if (i + 1) % self.sample_step == 0:
                    fake_images = self.generator(fixed_noise)
                    torchvision.utils.save_image(
                        self.denorm(fake_images.data),
                        os.path.join(
                            self.sample_path,
                            'fake_samples-%d-%d.png' % (epoch + 1, i + 1)))

            # save the model parameters for each epoch
            g_path = os.path.join(self.model_path,
                                  'generator-%d.pkl' % (epoch + 1))
            d_path = os.path.join(self.model_path,
                                  'discriminator-%d.pkl' % (epoch + 1))
            torch.save(self.generator.state_dict(), g_path)
            torch.save(self.discriminator.state_dict(), d_path)

    def sample(self):

        # Load trained parameters
        g_path = os.path.join(self.model_path,
                              'generator-%d.pkl' % (self.num_epochs))
        d_path = os.path.join(self.model_path,
                              'discriminator-%d.pkl' % (self.num_epochs))
        self.generator.load_state_dict(torch.load(g_path))
        self.discriminator.load_state_dict(torch.load(d_path))
        self.generator.eval()
        self.discriminator.eval()

        # Sample the images
        noise = self.to_variable(torch.randn(self.sample_size, self.z_dim))
        fake_images = self.generator(noise)
        sample_path = os.path.join(self.sample_path, 'fake_samples-final.png')
        torchvision.utils.save_image(self.denorm(fake_images.data),
                                     sample_path,
                                     nrow=12)

        print("Saved sampled images to '%s'" % sample_path)
        for g_iter in range(1):
            # generator
            optimizer_G.zero_grad()
            gen_imgs = torch.cat((img, output), 1)

            loss_G = -torch.mean(D(gen_imgs))
            loss_focal = criterion(output, label)
            loss = loss_focal + loss_G

            loss.backward()
            optimizer_G.step()

            train_loss += loss_focal.item() / trainSize

    G.eval(), D.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            img_v, label_v = batch[0].to(device), batch[1].to(device)

            output_v = G(img_v)
            loss = criterion(output_v, label_v)

            val_loss += loss.item() / valSize

    loss_track.append((train_loss, loss_G, loss_D, val_loss))
    torch.save(loss_track, 'checkpoint_GAN/loss.pth')

    print(
        '[{:4d}/{}], tr_ls: {:.5f}, G_ls: {:.5f}, D_ls: {:.5f}, te_ls: {:.5f}'.
        format(epoch + 1, epoch_num, train_loss, loss_G, loss_D, val_loss))