Example #1
    def train(self):
        print('train')

        n_items = None
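        # Build train/val loaders from the same base_dir; n_items=None
        # presumably loads every item in the split.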
        train_loader = get_numpy2d_dataset(
            base_dir=self.train_data_dir,
            num_processes=self.batch_size,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.target_size,
            drop_last=True,
            n_items=n_items,
            functions_dict=self.dataset_functions,
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.train_data_dir,
            num_processes=self.batch_size // 2,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.target_size,
            drop_last=True,
            n_items=n_items,
            functions_dict=self.dataset_functions,
        )

        for epoch in range(self.n_epochs):
            self.model.train()
            train_loss = 0

            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                loss = self.train_model(data)

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item():.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    # if self.logger is not None:
                    # self.tx.l[0].show_image_grid(inpt, name="Input", image_args={"normalize": True})
                    # self.tx.l[0].show_image_grid(inpt_rec, name="Reconstruction", image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.6f}"
            )

            # validate
            self.model.eval()

            val_loss = 0
            data_loader_ = tqdm(enumerate(val_loader))
            data_loader_.set_description_str("Validating")
            for _, data in data_loader_:
                loss = self.eval_model(data)
                val_loss += loss.item()

            self.tx.add_result(val_loss / len(val_loader),
                               name="Val-Loss",
                               tag="Losses",
                               counter=(epoch + 1) * len(train_loader))
            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.6f}"
            )

        self.tx.save_model(self.model, "model")
        time.sleep(2)  # presumably lets the experiment logger flush before exit
Example #2
    def train(self):
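        # Plain autoencoder: train with a pixel-wise MSE reconstruction loss,
        # then score the same loss on the "val" split after each epoch.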

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):

            ### Train
            self.model.train()

            train_loss = 0
            print("\nStart epoch ", epoch)
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                inpt = data.to(self.device)

                self.optimizer.zero_grad()
                inpt_rec = self.model(inpt)

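                # Pixel-wise MSE between input and reconstruction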
                loss = torch.mean(torch.pow(inpt - inpt_rec, 2))
                loss.backward()
                self.optimizer.step()

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item() / len(inpt):.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    if self.logger is not None:
                        self.tx.l[0].show_image_grid(
                            inpt, name="Input", image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            inpt_rec,
                            name="Reconstruction",
                            image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}"
            )

            ### Validate
            self.model.eval()

            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                data_loader_.set_description_str("Validating")
                for i, data in data_loader_:
                    inpt = data.to(self.device)
                    inpt_rec = self.model(inpt)

                    loss = torch.mean(torch.pow(inpt - inpt_rec, 2))
                    val_loss += loss.item()

                self.tx.add_result(val_loss / len(val_loader),
                                   name="Val-Loss",
                                   tag="Losses",
                                   counter=(epoch + 1) * len(train_loader))

            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}"
            )

        self.tx.save_model(self.model, "ae_final")

        time.sleep(10)
Example #3
    def train(self):
        n_items = None
        train_loader = get_numpy2d_dataset(
            base_dir=self.train_data_dir,
            num_processes=self.batch_size,
            pin_memory=True,
            batch_size=self.batch_size,
            mode="all",
            # target_size=self.target_size,
            drop_last=False,
            n_items=n_items,
            functions_dict=self.dataset_functions,
        )
        # val_loader = get_numpy2d_dataset(
        #     base_dir=self.test_data_dir,
        #     num_processes=self.batch_size // 2,
        #     pin_memory=True,
        #     batch_size=self.batch_size,
        #     mode="all",
        #     # target_size=self.target_size,
        #     drop_last=False,
        #     n_items=n_items,
        #     functions_dict=self.dataset_functions,
        # )
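        # Wrap the loader in a prefetcher, which presumably moves batches to the
        # GPU asynchronously (hence the commented-out .cuda() call below).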
        train_loader = DataPreFetcher(train_loader)
        # val_loader = DataPreFetcher(val_loader)

        for epoch in range(self.n_epochs):
            print('train')
            self.model.train()
            train_loss = 0

            data_loader_ = tqdm(enumerate(train_loader))
            # data_loader_ = enumerate(train_loader)
            for batch_idx, data in data_loader_:
                # data = data.cuda()
                loss, inpt, out = self.train_model(data)

                train_loss += loss.item()
                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch + 1} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item():.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx

                    # log the training loss to TensorBoard
                    self.tx.add_result(loss.item(),
                                       name="Train-Loss",
                                       tag="Losses",
                                       counter=cnt)

                    # if self.logger is not None:
                    #     self.tx.l[0].show_image_grid(inpt, name="Input", image_args={"normalize": True})
                    #     self.tx.l[0].show_image_grid(out, name="Reconstruction", image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch + 1} Average loss: {train_loss / len(train_loader):.6f}"
            )

            # print('validate')
            # self.model.eval()
            # val_loss = 0

            # data_loader_ = tqdm(enumerate(val_loader))
            # data_loader_.set_description_str("Validating")
            # for _, data in data_loader_:
            #     loss = self.eval_model(data)
            #     val_loss += loss.item()

            # self.tx.add_result(
            #     val_loss / len(val_loader), name="Val-Loss", tag="Losses", counter=(epoch + 1) * len(train_loader))
            # print(f"====> Epoch: {epoch + 1} Validation loss: {val_loss / len(val_loader):.6f}")

            # if (epoch + 1) % self.save_per_epoch == 0:
            if (epoch + 1) > self.n_epochs - 5:  # save checkpoints for the last five epochs only
                self.save_model(epoch + 1)

        time.sleep(2)
Example #4
    def train(self):

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.size,
            slice_offset=10,
        )

        print("Training GAN...")
        for epoch in range(self.n_epochs):
            # for epoch in range(0):

            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01  # map to [-1, 1] (assumes [0, 1] inputs) and add light noise

                real_imgs = batch.to(self.device)

                # ---------------------
                #  Train Discriminator
                # ---------------------
                # disc_cost = []
                # w_dist = []
                if i % self.critic_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = real_imgs.shape[0]

                    self.z.normal_()

                    fake_imgs = self.gen(self.z[:batch_size_curr])

                    real_validity = self.dis(real_imgs)
                    fake_validity = self.dis(fake_imgs)

                    gradient_penalty = self.calc_gradient_penalty(
                        self.dis,
                        real_imgs,
                        fake_imgs,
                        batch_size_curr,
                        self.size,
                        self.device,
                        self.gp_lambda,
                        n_image_channels=self.n_image_channels,
                    )

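                    # WGAN-GP critic loss: negated Wasserstein estimate plus the
                    # gradient penalty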
                    d_loss = -torch.mean(real_validity) + torch.mean(
                        fake_validity) + self.gp_lambda * gradient_penalty
                    d_loss.backward()
                    self.optimizer_D.step()

                    # disc_cost.append(d_loss.item())
                    w_dist = (-torch.mean(real_validity) +
                              torch.mean(fake_validity)).item()

                # -----------------
                #  Train Generator
                # -----------------
                # gen_cost = []
                if i % self.gen_iters == 0:
                    self.optimizer_G.zero_grad()
                    self.optimizer_D.zero_grad()

                    batch_size_curr = self.batch_size

                    fake_imgs = self.gen(self.z)

                    fake_validity = self.dis(fake_imgs)
                    g_loss = -torch.mean(fake_validity)

                    g_loss.backward()
                    self.optimizer_G.step()

                    # gen_cost.append(g_loss.item())

                if i % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{i}/{len(train_loader)} "
                        f" ({100.0 * i / len(train_loader):.0f}%)] Dis: "
                        f"{d_loss.item() / batch_size_curr:.6f} vs Gen: "
                        f"{g_loss.item() / batch_size_curr:.6f} (W-Dist: {w_dist / batch_size_curr:.6f})"
                    )
                    data_loader_.set_description_str(status_str)
                    # print(f"[Epoch {epoch}/{self.n_epochs}] [Batch {i}/{len(train_loader)}]")

                    # print(d_loss.item(), g_loss.item())
                    cnt = epoch * len(train_loader) + i

                    self.tx.add_result(d_loss.item(),
                                       name="trainDisCost",
                                       tag="DisVsGen",
                                       counter=cnt)
                    self.tx.add_result(g_loss.item(),
                                       name="trainGenCost",
                                       tag="DisVsGen",
                                       counter=cnt)
                    self.tx.add_result(w_dist,
                                       "wasserstein_distance",
                                       counter=cnt)

                    self.tx.l[0].show_image_grid(
                        fake_imgs.reshape(batch_size_curr,
                                          self.n_image_channels, self.size,
                                          self.size),
                        "GeneratedImages",
                        image_args={"normalize": True},
                    )

        self.tx.save_model(self.dis, "dis_final")
        self.tx.save_model(self.gen, "gen_final")

        self.gen.train(True)
        self.dis.train(True)

        if not self.use_encoder:
            time.sleep(10)
            return

        weight_features = self.enocoder_feature_weight
        weight_disc = self.encoder_discr_weight
        print("Training Encoder...")
        for epoch in range(self.n_epochs // 2):
            data_loader_ = tqdm(enumerate(train_loader))
            for i, batch in data_loader_:
                batch = batch * 2 - 1 + torch.randn_like(batch) * 0.01
                real_img = batch.to(self.device)
                batch_size_curr = real_img.shape[0]

                self.optimizer_G.zero_grad()
                self.optimizer_D.zero_grad()
                self.optimizer_E.zero_grad()

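                # Encode the real image and reconstruct it through the trained
                # generator; only the encoder is updated (optimizer_E below).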
                z = self.enc(real_img)
                recon_img = self.gen(z)

                _, img_feats = self.dis.forward_last_feature(real_img)
                disc_loss, recon_feats = self.dis.forward_last_feature(
                    recon_img)

                recon_img = recon_img.reshape(batch_size_curr,
                                              self.n_image_channels, self.size,
                                              self.size)
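                # Encoder loss: pixel MSE, plus weighted feature matching on the
                # discriminator's last features, plus a weighted adversarial term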
                loss_img = self.mse(real_img, recon_img)
                loss_feat = self.mse(img_feats, recon_feats) * weight_features
                disc_loss = -torch.mean(disc_loss) * weight_disc

                loss = loss_img + loss_feat + disc_loss

                loss.backward()
                self.optimizer_E.step()

                if i % self.print_every_iter == 0:
                    status_str = (
                        f"[Epoch {epoch}/{self.n_epochs // 2}] [Batch {i}/{len(train_loader)}] Loss:{loss:.06f}"
                    )
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + i
                    self.tx.add_result(loss.item(),
                                       name="EncoderLoss",
                                       counter=cnt)

                    self.tx.l[0].show_image_grid(
                        real_img.reshape(batch_size_curr,
                                         self.n_image_channels, self.size,
                                         self.size),
                        "RealImages",
                        image_args={"normalize": True},
                    )
                    self.tx.l[0].show_image_grid(
                        recon_img.reshape(batch_size_curr,
                                          self.n_image_channels, self.size,
                                          self.size),
                        "ReconImages",
                        image_args={"normalize": True},
                    )

        self.tx.save_model(self.enc, "enc_final")
        self.enc.train(False)

        time.sleep(10)
Example #5
    def train(self):

        train_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=16,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="train",
            target_size=self.input_shape[2],
        )
        val_loader = get_numpy2d_dataset(
            base_dir=self.data_dir,
            num_processes=8,
            pin_memory=False,
            batch_size=self.batch_size,
            mode="val",
            target_size=self.input_shape[2],
        )

        for epoch in range(self.n_epochs):

            self.model.train()
            train_loss = 0

            print("Start epoch")
            data_loader_ = tqdm(enumerate(train_loader))
            for batch_idx, data in data_loader_:
                data = data * 2 - 1  # map inputs to [-1, 1] (assumes [0, 1] range)
                self.optimizer.zero_grad()

                inpt = data.to(self.device)

                ### VAE Part
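                # VAE branch: KL term scaled by beta plus reconstruction scaled
                # by theta (theta may be adapted by GECO below)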
                loss_vae = 0
                if self.ce_factor < 1:
                    x_rec_vae, z_dist = self.model(inpt)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss_vae = self.rec_loss_fn(x_rec_vae, inpt)
                    loss_vae = kl_loss + rec_loss_vae * self.theta

                ### CE Part
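                # Context-encoding branch: overwrite random squares with noise
                # and train the model to restore the clean input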
                loss_ce = 0
                if self.ce_factor > 0:

                    ce_tensor = get_square_mask(
                        data.shape,
                        square_size=(0, np.max(self.input_shape[2:]) // 2),
                        noise_val=(torch.min(data).item(),
                                   torch.max(data).item()),
                        n_squares=(0, 3),
                    )
                    ce_tensor = torch.from_numpy(ce_tensor).float()
                    inpt_noisy = torch.where(ce_tensor != 0, ce_tensor, data)

                    inpt_noisy = inpt_noisy.to(self.device)
                    x_rec_ce, _ = self.model(inpt_noisy)
                    rec_loss_ce = self.rec_loss_fn(x_rec_ce, inpt)
                    loss_ce = rec_loss_ce

                loss = (1.0 -
                        self.ce_factor) * loss_vae + self.ce_factor * loss_ce

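                # GECO-style update: track the VAE reconstruction loss with an
                # EMA and adapt theta toward the target g_goal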
                if self.use_geco and self.ce_factor < 1:
                    g_goal = 0.1
                    g_lr = 1e-4
                    self.vae_loss_ema = (
                        1.0 - 0.9) * rec_loss_vae + 0.9 * self.vae_loss_ema
                    self.theta = self.geco_beta_update(self.theta,
                                                       self.vae_loss_ema,
                                                       g_goal,
                                                       g_lr,
                                                       speedup=2)

                if torch.isnan(loss):
                    print("A wild NaN occurred")
                    continue

                loss.backward()
                self.optimizer.step()
                train_loss += loss.item()

                if batch_idx % self.print_every_iter == 0:
                    status_str = (
                        f"Train Epoch: {epoch} [{batch_idx}/{len(train_loader)} "
                        f" ({100.0 * batch_idx / len(train_loader):.0f}%)] Loss: "
                        f"{loss.item() / len(inpt):.6f}")
                    data_loader_.set_description_str(status_str)

                    cnt = epoch * len(train_loader) + batch_idx

                    if self.ce_factor < 1:
                        self.tx.l[0].show_image_grid(
                            inpt,
                            name="Input-VAE",
                            image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            x_rec_vae,
                            name="Output-VAE",
                            image_args={"normalize": True})

                        if self.beta > 0:
                            self.tx.add_result(torch.mean(kl_loss).item(),
                                               name="Kl-loss",
                                               tag="Losses",
                                               counter=cnt)
                        self.tx.add_result(torch.mean(rec_loss_vae).item(),
                                           name="Rec-loss",
                                           tag="Losses",
                                           counter=cnt)
                        self.tx.add_result(loss_vae.item(),
                                           name="Train-loss",
                                           tag="Losses",
                                           counter=cnt)

                    if self.ce_factor > 0:
                        self.tx.l[0].show_image_grid(
                            inpt_noisy,
                            name="Input-CE",
                            image_args={"normalize": True})
                        self.tx.l[0].show_image_grid(
                            x_rec_ce,
                            name="Output-CE",
                            image_args={"normalize": True})

            print(
                f"====> Epoch: {epoch} Average loss: {train_loss / len(train_loader):.4f}"
            )

            self.model.eval()

            val_loss = 0
            with torch.no_grad():
                data_loader_ = tqdm(enumerate(val_loader))
                for i, data in data_loader_:
                    data = data * 2 - 1
                    inpt = data.to(self.device)

                    x_rec, z_dist = self.model(inpt, sample=False)

                    kl_loss = 0
                    if self.beta > 0:
                        kl_loss = self.kl_loss_fn(z_dist) * self.beta
                    rec_loss = self.rec_loss_fn(x_rec, inpt)
                    loss = kl_loss + rec_loss * self.theta

                    val_loss += loss.item()

                self.tx.add_result(val_loss / len(val_loader),
                                   name="Val-Loss",
                                   tag="Losses",
                                   counter=(epoch + 1) * len(train_loader))

            print(
                f"====> Epoch: {epoch} Validation loss: {val_loss / len(val_loader):.4f}"
            )

        self.tx.save_model(self.model, "vae_final")

        time.sleep(10)