Example #1
    def train_by_epoch(self, start_epoch=1):
        self.device = get_device(self.model)
        self.model.train()

        start_all = time()
        for epoch_cnt in range(start_epoch, self.opt.max_epoch + 1):
            self.logger.info(f"\n[ Epoch {epoch_cnt} ]")

            start_span = time()
            avg_epoch_loss = self._train_epoch()
            time_span = (time() - start_span) / 60
            self.logger.info(
                f"word_loss : {avg_epoch_loss:.2f}, time : {time_span:.2f} min"
            )

            if self.validator is not None:
                state_dict = self.validation(epoch_cnt)
            else:
                state_dict = get_state_dict(self.model)

            if epoch_cnt > self.opt.max_epoch / 3:
                self.save_model(epoch_cnt, state_dict,
                                f"epoch_{epoch_cnt}.pth")

        time_all = (time() - start_all) / 3600
        self.logger.info(
            f"\nbest_epoch : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
        )
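Example #1 relies on helpers such as get_device and get_state_dict that this listing does not show. A plausible one-liner for get_device, offered as an assumption rather than the repository's actual definition:

def get_device(model):
    # Device of the model's first parameter (assumes the model has
    # parameters and is not sharded across devices).
    return next(model.parameters()).device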
Example #2
    def __init__(self, src_DAMSM_CNN, src_DAMSM_RNN, tgt_DAMSM_RNN, netG, netsD,
                 netG_optimizer, netD_optimizers, train_loader, scaler, opt):
        self.src_DAMSM_CNN = src_DAMSM_CNN
        self.src_DAMSM_RNN = src_DAMSM_RNN
        self.tgt_DAMSM_RNN = tgt_DAMSM_RNN
        self.netG = netG
        self.netsD = netsD
        self.netG_optimizer = netG_optimizer
        self.netD_optimizers = netD_optimizers
        self.train_loader = train_loader
        self.scaler = scaler
        self.opt = opt

        self.save_model_dir = opt.save_model_dir
        self.save_image_dir = opt.save_image_dir
        self.stage_num = opt.stage_num

        self.device = get_device(self.netG)
        self.avg_param_G = copy_params(self.netG)
        self.real_labels = torch.FloatTensor(opt.batch_size).fill_(1).to(self.device)
        self.fake_labels = torch.FloatTensor(opt.batch_size).fill_(0).to(self.device)
        self.match_labels = torch.arange(opt.batch_size).to(self.device)
        self.noise = torch.FloatTensor(opt.batch_size, opt.noise_dim).to(self.device)
        self.fixed_noise = torch.FloatTensor(opt.batch_size, opt.noise_dim).normal_(0, 1).to(self.device)
        self.src_word_embs, self.src_sent_emb, self.src_mask, self.tgt_word_embs, self.tgt_sent_emb, self.tgt_mask = self.get_fixed_embs()

        self.logger = get_logger(opt.save_log_path, overwrite=opt.overwrite)
        self.logger.info(args2string(opt))
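Examples #2 and #7 keep avg_param_G, an averaged copy of the generator's weights that is swapped in for evaluation. A minimal sketch of that pattern, assuming copy_params and load_params behave as in AttnGAN-style code:

import torch
import torch.nn as nn

def copy_params(model):
    # Detached copies of the current parameter tensors.
    return [p.data.clone() for p in model.parameters()]

def load_params(model, new_params):
    # Overwrite the model's parameters in place.
    for p, new_p in zip(model.parameters(), new_params):
        p.data.copy_(new_p)

netG = nn.Linear(4, 4)              # stand-in generator
avg_param_G = copy_params(netG)

# Per training step: exponential moving average of the weights.
for p, avg_p in zip(netG.parameters(), avg_param_G):
    avg_p.mul_(0.999).add_(p.data, alpha=0.001)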
Example #3
    def train(self, start_epoch=1):
        self.device = get_device(self.image_encoder)
        self.image_encoder.train()
        self.text_encoder.train()

        lr = self.opt.lr
        start_all = time()
        for epoch_cnt in range(start_epoch, self.opt.max_epoch + 1):
            self.logger.info(f"\n[ Epoch {epoch_cnt} ]")

            start_span = time()
            train_losses = self._train_epoch()
            time_span = (time() - start_span) / 60
            self.logger.info(
                f"train_w_loss : {train_losses[0]:.2f} {train_losses[1]:.2f} "
                f"train_s_loss : {train_losses[2]:.2f} {train_losses[3]:.2f} "
                f"time : {time_span:.2f}"
            )

            if epoch_cnt > self.opt.max_epoch / 3:
                if self.valid_loader is not None:
                    valid_losses = self.evaluate()
                    self.logger.info(
                        f"valid_w_loss : {valid_losses[0]:.2f} "
                        f"valid_s_loss : {valid_losses[1]:.2f} "
                        f"lr : {lr:.5f}"
                    )
                    if sum(valid_losses) <= self.best_score:
                        self.best_score = sum(valid_losses)
                        self.best_epoch = epoch_cnt
                        self.save_model(epoch_cnt, "best.pth")

                if epoch_cnt % self.opt.save_freq == 0:
                    self.save_model(epoch_cnt, f"epoch_{epoch_cnt}.pth")

            # decay the learning rate each epoch until it reaches lr / 10
            if lr > self.opt.lr / 10.:
                lr *= 0.98
                for param_group in self.image_optimizer.param_groups:
                    param_group['lr'] = lr
                for param_group in self.text_optimizer.param_groups:
                    param_group['lr'] = lr

        time_all = (time() - start_all) / 3600
        self.logger.info(f"\nbest_epoch : {self.best_epoch}, best_score : {self.best_score}, time : {time_all:.2f} h")
Example #4
    def __init__(self, src_DAMSM_RNN, tgt_DAMSM_RNN, netG, MMT, dataloader,
                 opt):
        self.src_DAMSM_RNN = src_DAMSM_RNN
        self.tgt_DAMSM_RNN = tgt_DAMSM_RNN
        self.netG = netG
        self.MMT = MMT

        self.dataloader = dataloader
        self.device = get_device(self.netG)
        self.opt = opt

        self.model_connector = ModelConnector(
            opt.batch_size,
            opt.words_limit,
            MMT_id2word=dataloader.dataset.MMT_index2word,
            T2I_word2id=dataloader.dataset.tgt_word2index,
            bpe=opt.bpe,
        )
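ModelConnector itself is not shown in this listing. A hedged sketch of what such a vocabulary bridge plausibly does, given the arguments it receives above: map MMT output ids back to words, undo BPE joins, and re-index with the T2I vocabulary. The function below is illustrative, not the actual API:

import re

def bridge_ids(mmt_ids, mmt_id2word, t2i_word2id, unk_id=1, bpe=True):
    # MMT ids -> words -> (optionally) merged BPE pieces -> T2I ids.
    text = ' '.join(mmt_id2word[i] for i in mmt_ids)
    if bpe:
        text = re.sub(r'@@( |$)', '', text)  # undo "word@@ piece" splits
    return [t2i_word2id.get(w, unk_id) for w in text.split()]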
Example #5
    def __init__(self, image_encoder, text_encoder, train_loader,
                 image_optimizer, text_optimizer, scaler, opt, valid_loader=None):
        self.image_encoder = image_encoder
        self.text_encoder = text_encoder
        self.train_loader = train_loader
        self.image_optimizer = image_optimizer
        self.text_optimizer = text_optimizer
        self.scaler = scaler
        self.opt = opt
        self.valid_loader = valid_loader

        self.device = get_device(self.image_encoder)
        self.logger = get_logger(opt.save_log_path, overwrite=opt.overwrite)
        self.logger.info(args2string(opt))

        self.labels = torch.arange(opt.batch_size).to(self.device)
        
        self.best_epoch = 0
        self.best_score = 99999.
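The labels = torch.arange(opt.batch_size) buffer (like match_labels in Examples #2 and #9) exists because DAMSM-style losses score an image-text similarity matrix against its diagonal: item i of the batch matches caption i. A minimal sketch of that idea:

import torch
import torch.nn.functional as F

batch_size, dim = 8, 16
img = F.normalize(torch.randn(batch_size, dim), dim=1)
txt = F.normalize(torch.randn(batch_size, dim), dim=1)
labels = torch.arange(batch_size)

sim = img @ txt.t() * 10.0    # scaled cosine similarities (scale is a choice)
loss = F.cross_entropy(sim, labels) + F.cross_entropy(sim.t(), labels)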
Example #6
    def __init__(self,
                 model,
                 train_loader,
                 optimizer,
                 scaler,
                 scheduler,
                 opt,
                 validator=None):
        self.model = model
        self.train_loader = train_loader
        self.optimizer = optimizer
        self.scaler = scaler
        self.scheduler = scheduler
        self.opt = opt
        self.validator = validator

        self.device = get_device(self.model)
        self.logger = get_logger(opt.save_log_path, overwrite=opt.overwrite)
        self.logger.info(args2string(opt))

        self.best_bleu_score = 0.
        self.best_cnt = 0
Example #7
    def train(self, start_cnt):
        self.device = get_device(self.MNMT)

        start_all = time()
        for epoch_cnt in range(start_cnt, self.opt.max_epoch + 1):
            self.logger.info(f"\n[ Epoch {epoch_cnt} ]")

            # --- train MNMT ---
            start_span = time()
            # swap the averaged generator weights in for the MNMT phase;
            # the raw weights are restored before the T2I phase below
            backup_para = copy_params(self.netG)
            load_params(self.netG, self.avg_param_G)
            logs = self.MNMT_train_epoch()
            time_span = (time() - start_span) / 60
            self.logger.info(f"{logs}, time : {time_span:.2f} min")

            # --- valid MNMT ---
            if self.validator is not None:
                state_dict = self.validation(epoch_cnt)
            else:
                state_dict = get_state_dict(self.MNMT)
            if self.stop_cnt == self.opt.early_stop:
                break
            self.save_models(epoch_cnt, state_dict, f"epoch_{epoch_cnt}.pth")

            # --- train T2I ---
            start_span = time()
            load_params(self.netG, backup_para)
            for _ in range(self.opt.T2I_per_MNMT):
                D_logs, G_logs = self.T2I_train_epoch()
            time_span = (time() - start_span) / 60
            self.logger.info(f"{D_logs}\n{G_logs}\ntime : {time_span:.2f} min")

        time_all = (time() - start_all) / 3600
        self.logger.info(
            f"\nbest_epoch : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
        )
Example #8
    def train(self, start_epoch=1):
        self.device = get_device(self.netG)
        self.netG.train()
        for i in range(self.stage_num):
            self.netsD[i].train()

        start_all = time()
        for epoch_cnt in range(start_epoch, self.opt.max_epoch + 1):
            self.logger.info(f"\n[ Epoch {epoch_cnt} ]")

            start_span = time()
            D_logs, G_logs = self.train_epoch()
            time_span = (time() - start_span) / 60
            self.logger.info(f"{D_logs}\n{G_logs}\ntime : {time_span:.2f} min")

            if epoch_cnt % self.opt.display_freq == 0:
                self.save_fixed_images(epoch_cnt)

            if epoch_cnt % self.opt.save_freq == 0:
                self.save_model(epoch_cnt, f"epoch_{epoch_cnt}.pth")

        # final save after the loop (overwrites the same file when max_epoch
        # falls on a save_freq boundary)
        self.save_model(epoch_cnt, f"epoch_{epoch_cnt}.pth")
        time_all = (time() - start_all) / 3600
        self.logger.info(f"time : {time_all:.2f} h")
Example #9
    def __init__(self,
                 MNMT,
                 src_DAMSM_CNN,
                 src_DAMSM_RNN,
                 tgt_DAMSM_RNN,
                 netG,
                 netsD,
                 MNMT_optimizer,
                 netG_optimizer,
                 netD_optimizers,
                 DAMSM_optimizer,
                 MNMT_loader,
                 T2I_loader,
                 scaler,
                 scheduler,
                 opt,
                 validator=None):
        self.MNMT = MNMT
        self.src_DAMSM_CNN = src_DAMSM_CNN
        self.src_DAMSM_RNN = src_DAMSM_RNN
        self.tgt_DAMSM_RNN = tgt_DAMSM_RNN
        self.netG = netG
        self.netsD = netsD
        self.MNMT_optimizer = MNMT_optimizer
        self.netG_optimizer = netG_optimizer
        self.netD_optimizers = netD_optimizers
        self.DAMSM_optimizer = DAMSM_optimizer
        self.MNMT_loader = MNMT_loader
        self.T2I_loader = T2I_loader
        self.scaler = scaler
        self.scheduler = scheduler
        self.opt = opt
        self.validator = validator

        self.save_model_dir = opt.save_model_dir
        self.save_image_dir = opt.save_image_dir
        self.stage_num = opt.stage_num

        self.device = get_device(self.netG)
        self.avg_param_G = copy_params(self.netG)
        self.real_labels = torch.FloatTensor(opt.T2I_batch_size).fill_(1).to(
            self.device)
        self.fake_labels = torch.FloatTensor(opt.T2I_batch_size).fill_(0).to(
            self.device)
        self.match_labels = torch.arange(opt.T2I_batch_size).to(self.device)
        self.T2I_noise = torch.FloatTensor(opt.T2I_batch_size,
                                           opt.noise_dim).to(self.device)
        self.MNMT_noise = torch.FloatTensor(opt.MNMT_batch_size,
                                            opt.noise_dim).to(self.device)
        self.model_connector = ModelConnector(
            opt.T2I_batch_size,
            opt.train_words_limit,
            MNMT_id2word=MNMT_loader.dataset.tgt_index2word,
            T2I_word2id=T2I_loader.dataset.tgt_word2index,
            bpe=opt.bpe,
        )

        self.logger = get_logger(opt.save_log_path, overwrite=opt.overwrite)
        self.logger.info(args2string(opt))

        self.best_bleu_score = 0.
        self.best_cnt = 0
        self.stop_cnt = 0
Example #10
    def sampling(self):
        self.src_DAMSM_RNN.eval()
        self.tgt_DAMSM_RNN.eval()
        self.netG.eval()
        self.MMT.eval()

        device = get_device(self.netG)

        save_dir = self.opt.save_image_dir
        noise = torch.FloatTensor(self.opt.batch_size,
                                  self.opt.noise_dim).to(device)

        for _ in range(1):  # (cfg.TEXT.CAPTIONS_PER_IMAGE):
            pbar = tqdm(self.dataloader, ascii=True, mininterval=0.5, ncols=90)
            for imgs, MMT_src_text, MMT_src_pos, T2I_src_text, T2I_src_len, filenames in pbar:
                T2I_tgt_text, T2I_tgt_len = self.generate_T2I_tgt_text(
                    imgs[-1], MMT_src_text, MMT_src_pos)

                batch_size = MMT_src_text.size(0)
                T2I_src_text = T2I_src_text.to(device)
                T2I_src_len = T2I_src_len.to(device)
                T2I_tgt_text = T2I_tgt_text.to(device)
                T2I_tgt_len = T2I_tgt_len.to(device)

                ##########################################################
                # (1) Prepare training data and Compute text embeddings
                ##########################################################
                src_words_embs, src_sent_emb = self.src_DAMSM_RNN(
                    T2I_src_text, T2I_src_len)
                src_mask = (T2I_src_text == Constants.PAD)
                num_words = src_words_embs.size(2)
                if src_mask.size(1) > num_words:
                    src_mask = src_mask[:, :num_words]

                tgt_words_embs, tgt_sent_emb = self.tgt_DAMSM_RNN(
                    T2I_tgt_text, T2I_tgt_len)
                tgt_mask = (T2I_tgt_text == Constants.PAD)
                num_words = tgt_words_embs.size(2)
                if tgt_mask.size(1) > num_words:
                    tgt_mask = tgt_mask[:, :num_words]

                ##########################################################
                # (2) Generate fake images
                ##########################################################
                noise.data.normal_(0, 1)
                fake_imgs, _, _, _ = self.netG(noise[:batch_size],
                                               src_words_embs, tgt_words_embs,
                                               src_sent_emb, tgt_sent_emb,
                                               src_mask, tgt_mask)

                ##########################################################
                # (3) Save images
                ##########################################################
                for j in range(batch_size):
                    file_path = '%s/%s' % (save_dir, filenames[j])
                    k = -1
                    # for k in range(len(fake_imgs)):
                    im = fake_imgs[k][j].data.cpu().numpy()
                    # [-1, 1] --> [0, 255]
                    im = (im + 1.0) * 127.5
                    im = im.astype(np.uint8)
                    im = np.transpose(im, (1, 2, 0))
                    im = Image.fromarray(im)
                    im.save(file_path)
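The save loop above (and the one in Example #11) repeats the same [-1, 1] tensor to 8-bit image conversion inline. A straightforward helper, offered as a refactor suggestion rather than code from the repository:

import numpy as np
from PIL import Image

def tensor_to_image(t):
    # CHW float tensor in [-1, 1] -> HWC uint8 -> PIL image.
    im = (t.data.cpu().numpy() + 1.0) * 127.5
    return Image.fromarray(np.transpose(im.astype(np.uint8), (1, 2, 0)))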
Example #11
    def sampling_from_image_id(self, image_id):
        self.src_DAMSM_RNN.eval()
        self.tgt_DAMSM_RNN.eval()
        self.netG.eval()

        device = get_device(self.netG)

        image_name = image_id + '.jpg'
        dataset = self.dataloader.dataset
        img_insts = dataset.img_insts
        inst_id = []
        for i in range(len(img_insts)):
            if img_insts[i].find(image_name) != -1:
                inst_id.append(i)

        if len(inst_id) != 1:
            print(f'[Error] expected exactly one instance matching '
                  f'{image_name}, found {len(inst_id)}')
            exit()

        inst_id = inst_id[0]

        save_dir = self.opt.save_image_dir / image_id
        mkdirs(save_dir)
        noise = torch.FloatTensor(1, self.opt.noise_dim).to(device)
        for i in range(self.opt.noise_num):
            src_inst = dataset.T2I_src_insts[inst_id]
            if dataset.bpe_index2word is not None:
                src_inst = [dataset.bpe_index2word[idx] for idx in src_inst]
                src_inst = ' '.join(src_inst)
                src_inst = self.bpe_re.sub('', src_inst)
                src_inst = src_inst.split(' ')
                src_inst = [
                    dataset.src_word2index.get(word, Constants.UNK)
                    for word in src_inst
                ]
            tgt_inst = dataset.T2I_tgt_insts[inst_id]
            tgt_inst = [
                dataset.tgt_word2index.get(word, Constants.UNK)
                for word in tgt_inst
            ]
            GAN_src_text, GAN_src_len = dataset.get_T2I_text(src_inst)
            GAN_tgt_text, GAN_tgt_len = dataset.get_T2I_text(tgt_inst)
            GAN_src_text = torch.LongTensor(GAN_src_text).unsqueeze(0).to(
                device)
            GAN_src_len = torch.LongTensor([GAN_src_len]).to(device)
            GAN_tgt_text = torch.LongTensor(GAN_tgt_text).unsqueeze(0).to(
                device)
            GAN_tgt_len = torch.LongTensor([GAN_tgt_len]).to(device)

            src_words_embs, src_sent_emb = self.src_DAMSM_RNN(
                GAN_src_text, GAN_src_len)
            src_mask = (GAN_src_text == Constants.PAD)
            num_words = src_words_embs.size(2)
            if src_mask.size(1) > num_words:
                src_mask = src_mask[:, :num_words]

            tgt_words_embs, tgt_sent_emb = self.tgt_DAMSM_RNN(
                GAN_tgt_text, GAN_tgt_len)
            tgt_mask = (GAN_tgt_text == Constants.PAD)
            num_words = tgt_words_embs.size(2)
            if tgt_mask.size(1) > num_words:
                tgt_mask = tgt_mask[:, :num_words]

            noise.data.normal_(0, 1)
            fake_imgs, _, _, _ = self.netG(noise, src_words_embs,
                                           tgt_words_embs, src_sent_emb,
                                           tgt_sent_emb, src_mask, tgt_mask)

            file_path = '%s/%s_fake%d.png' % (save_dir, image_id, i)
            k = -1
            # for k in range(len(fake_imgs)):
            im = fake_imgs[k][0].data.cpu().numpy()
            # [-1, 1] --> [0, 255]
            im = (im + 1.0) * 127.5
            im = im.astype(np.uint8)
            im = np.transpose(im, (1, 2, 0))
            im = Image.fromarray(im)
            im.save(file_path)
Example #12
    def train_by_step(self, start_step=1):
        self.device = get_device(self.model)

        self.model.train()
        self.optimizer.zero_grad()
        data_iter = iter(self.train_loader)
        checkpoint_interval = 1500
        checkpoint_cnt = 1
        checkpoint_loss = 0.

        step_cnt = start_step
        iter_num = (self.opt.max_step - start_step +
                    1) * self.opt.grad_accumulation
        start_all = time()
        start_span = time()
        pbar = tqdm(range(iter_num), ncols=90, mininterval=0.5, ascii=True)
        for _ in pbar:
            try:
                train_datas = next(data_iter)  # data_iter.next() is Python 2 only
            except StopIteration:
                data_iter = iter(self.train_loader)
                train_datas = next(data_iter)

            with autocast(self.opt.use_amp):
                loss, batch_size = self.cal_loss(*train_datas)
            checkpoint_loss += loss.item()
            loss /= self.opt.grad_accumulation

            self.scaler.scale(loss).backward()

            if checkpoint_cnt % self.opt.grad_accumulation == 0:
                self.scaler.unscale_(self.optimizer)
                clip_grad_norm_(self.model.parameters(), self.opt.max_norm)
                self.scheduler.update_lr()
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad()

                avg_checkpoint_loss = checkpoint_loss / checkpoint_cnt
                pbar.set_description(
                    f"[{step_cnt}/{self.opt.max_step}] "
                    f"word_loss : {avg_checkpoint_loss:.2f}, batch_size : {batch_size:<5}"
                )

                if step_cnt % checkpoint_interval == 0:
                    print()
                    time_span = (time() - start_span) / 60
                    self.logger.info(f"\n[ Step {step_cnt} ]")
                    self.logger.info(
                        f"word_loss : {avg_checkpoint_loss:.2f}, time : {time_span:.2f} min"
                    )

                    if self.validator is not None:
                        state_dict = self.validation(step_cnt)
                    else:
                        state_dict = get_state_dict(self.model)

                    if step_cnt > self.opt.max_step / 3:
                        self.save_model(step_cnt, state_dict,
                                        f"step_{step_cnt}.pth")

                    start_span = time()
                    checkpoint_cnt = 0
                    checkpoint_loss = 0.
                    print()

                step_cnt += 1
            checkpoint_cnt += 1

        step_cnt = step_cnt - 1
        if step_cnt % checkpoint_interval != 0:
            self.logger.info(f"\n[ Step {step_cnt} ]")
            self.logger.info(f"word_loss : {avg_checkpoint_loss:.2f}")
            if self.validator is not None:
                state_dict = self.validation(step_cnt)
            else:
                state_dict = get_state_dict(self.model)
            self.save_model(step_cnt, state_dict, f"step_{step_cnt}.pth")

        time_all = (time() - start_all) / 3600
        self.logger.info(
            f"\nbest_step : {self.best_cnt}, best_score : {self.best_bleu_score}, time : {time_all:.2f} h"
        )
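Example #12's inner update combines AMP loss scaling, gradient accumulation, and gradient-norm clipping. A self-contained toy version of that sequence (the model and data below are stand-ins, not the repository's):

import torch
from torch import nn
from torch.cuda.amp import GradScaler
from torch.nn.utils import clip_grad_norm_

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler(enabled=False)     # enable on CUDA
grad_accumulation, max_norm = 4, 1.0

optimizer.zero_grad()
for step, x in enumerate(torch.randn(16, 10).split(2), start=1):
    # Scale the loss down so accumulated gradients average over micro-batches.
    loss = model(x).pow(2).mean() / grad_accumulation
    scaler.scale(loss).backward()
    if step % grad_accumulation == 0:
        scaler.unscale_(optimizer)     # so clipping sees true gradient norms
        clip_grad_norm_(model.parameters(), max_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()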