def train(self):
        """Run train/validation epochs until ``params["train_epoch"]`` is reached.

        Each epoch calls ``self.train_loop()`` and then ``self.test_loop()``.
        Every ``params["save_img_every"]`` epochs a test-set preview is written
        with ``helper.show_test`` to ``output/<save_root>_val_<epoch>.jpg``;
        every ``params["save_every"]`` epochs ``self.save_state`` is called with
        ``output/<save_root>_<epoch>.json`` and its returned string is echoed.
        Epoch wall time and the per-loss summaries produced by ``helper.mft``
        for ``self.loss_epoch_dict`` (train) and ``self.loss_epoch_dict_test``
        (validation) are printed.

        The loop condition reads ``self.current_epoch``, so a trainer whose
        counter was restored only runs the remaining epochs.
        """
        params = self.params
        while self.current_epoch < params["train_epoch"]:
            epoch_start_time = time.time()

            self.train_loop()
            self.test_loop()

            # Generate test images and save to disk.
            if self.current_epoch % params["save_img_every"] == 0:
                helper.show_test(
                    params,
                    self.transform,
                    self.test_loader,
                    self.model_dict['G'],
                    save=f'output/{params["save_root"]}_val_{self.current_epoch}.jpg',
                )

            # Checkpoint.
            if self.current_epoch % params["save_every"] == 0:
                save_str = self.save_state(
                    f'output/{params["save_root"]}_{self.current_epoch}.json')
                tqdm.write(save_str)

            per_epoch_ptime = time.time() - epoch_start_time
            print(f'Epoch Training Training Time: {per_epoch_ptime}')
            # Plain loops: printing is a side effect, not list construction.
            for loss in self.losses:
                print(f'Train {loss}: {helper.mft(self.loss_epoch_dict[loss])}')
            for loss in self.losses:
                print(f'Val {loss}: {helper.mft(self.loss_epoch_dict_test[loss])}')
            print('\n')
            self.current_epoch += 1

        self.display_history()
        print('Hit End of Learning Schedule!')
Exemplo n.º 2
0
    def train(self):
        """Run ``params["train_epoch"]`` epochs of ``self.train_loop()``.

        The number of epochs is fixed by ``params["train_epoch"]``.
        ``self.current_epoch`` is only used (and advanced) for the save cadence
        and output file names, so a resumed trainer still runs the full count.

        Every ``params['save_img_every']`` epochs ``helper.show_test`` is called
        with ``G`` and ``self.preview_noise`` to write
        ``output/<save_root>_<epoch>.jpg``. Every ``params['save_every']``
        epochs the history is displayed and ``self.save_state`` is called with
        ``output/<save_root>_<epoch>.json``.
        """
        params = self.params
        for _ in range(params["train_epoch"]):
            epoch_start_time = time.time()

            self.train_loop()

            # Save preview image, or weights and loss graph.
            if self.current_epoch % params['save_img_every'] == 0:
                helper.show_test(
                    self.model_dict['G'],
                    Variable(self.preview_noise),
                    self.transform,
                    save=f'output/{params["save_root"]}_{self.current_epoch}.jpg')
            if self.current_epoch % params['save_every'] == 0:
                self.display_history()
                save_str = self.save_state(
                    f'output/{params["save_root"]}_{self.current_epoch}.json')
                print(save_str)

            per_epoch_ptime = time.time() - epoch_start_time
            print(f'Epoch Training Training Time: {per_epoch_ptime}')
            for loss in self.losses:
                print(f'Train {loss}: {helper.mft(self.loss_epoch_dict[loss])}')
            print('\n')
            self.current_epoch += 1

        self.display_history()
        print('Hit End of Learning Schedule!')
    def train(self):
        """Train the discriminator against ``dream_on_mesh`` output for N epochs.

        Runs ``params["train_epoch"]`` epochs over ``self.train_loader``. For
        each batch:

        1. The real batch is moved to the GPU and blended with Gaussian noise
           (``0.9 * real + 0.1 * N(0, 0.5)``).
        2. ``fake`` is produced by ``self.dream_on_mesh`` with generator-side
           grads enabled.
        3. ``self.train_disc(fake, real)`` runs with discriminator-side grads
           enabled.

        After each epoch ``self.current_epoch`` is advanced. Every
        ``params['save_img_every']`` epochs (checked on the advanced counter) a
        preview of the last batch and the mesh textures is written. Then
        ``self.save_state`` is called and the epoch's loss summaries are
        recorded and printed.
        """
        params = self.params
        for _ in range(params["train_epoch"]):

            # Clear last epoch's losses.
            for loss in self.losses:
                self.loss_epoch_dict[loss] = []

            print(
                f"Sched Iter:{self.current_iter}, Sched Epoch:{self.current_epoch}"
            )
            for opt in self.opt_dict:
                print(
                    f"Learning Rate({opt}): {self.opt_dict[opt].param_groups[0]['lr']}"
                )

            self.loop_iter = 0
            epoch_start_time = time.time()

            for real_data in tqdm(self.train_loader):
                real = real_data.cuda()
                # Add some noise.
                real = (real * .9) + (
                    .1 * torch.FloatTensor(real.shape).normal_(0, .5).cuda())

                # DREAM #
                self.set_grad_req(d=False, g=True)
                fake = self.dream_on_mesh(batch_size=real.shape[0])

                # TRAIN DISC #
                self.set_grad_req(d=True, g=False)
                self.train_disc(fake, real)

                # Append all losses in loss dict as Python floats.
                for loss in self.losses:
                    self.loss_epoch_dict[loss].append(
                        self.loss_batch_dict[loss].item())
                self.loop_iter += 1
                self.current_iter += 1

            self.current_epoch += 1

            # Gate previews on the epoch counter. self.loop_iter is this epoch's
            # batch count, which is the same every epoch, so it cannot express
            # "every N epochs". Also skip when the loader yielded no batches,
            # since `real` and `fake` would then be unbound.
            if (self.loop_iter > 0
                    and self.current_epoch % params['save_img_every'] == 0):
                # NOTE(review): the view below assumes textures reshape to
                # (1, grid_res - 1, grid_res - 1, 48); confirm against the mesh.
                textures = self.model_dict['M'].textures.unsqueeze(0).view(
                    1, params['grid_res'] - 1, params['grid_res'] - 1,
                    48).permute(0, 3, 1, 2)
                helper.show_test(
                    real,
                    fake,
                    self.t2v(textures),
                    self.transform,
                    save=f'output/{params["save_root"]}_{self.current_epoch}.jpg')
            save_str = self.save_state(
                f'output/{params["save_root"]}_{self.current_epoch}.json')
            print(save_str)

            per_epoch_ptime = time.time() - epoch_start_time
            for loss in self.losses:
                self.train_hist_dict[loss].append(
                    helper.mft(self.loss_epoch_dict[loss]))
            print(f'Epoch:{self.current_epoch}, Epoch Time:{per_epoch_ptime}')
            for loss in self.losses:
                print(f'Train {loss}: {helper.mft(self.loss_epoch_dict[loss])}')

        self.display_history()
        print('Hit End of Learning Schedule!')
Exemplo n.º 4
0
    def train(self):
        """Train ``G`` and ``D`` following the iteration-indexed LR schedule.

        Makes repeated passes ("cycles") over ``self.train_loader`` until
        ``self.current_iter`` runs past the end of ``self.iter_stack``.

        Per batch, ``self.lr_lookup()`` returns ``(lr_mult, save)``:

        - ``lr_mult`` scales ``params['lr_disc']`` for optimizer ``D`` and
          ``params['lr_gen']`` for optimizer ``G``.
        - When ``save`` is truthy, ``self.save_state`` is called and
          ``self.current_epoch`` advances.

        After every cycle a test preview named by ``self.current_cycle`` is
        written and ``self.test_loop()`` runs. Unless the schedule just ended,
        timing and train/val loss summaries are printed and the train summaries
        are appended to ``self.train_hist_dict``.

        Raises:
            RuntimeError: if ``self.train_loader`` yields no batches before the
                schedule is exhausted. ``current_iter`` could then never
                advance and the loop would never terminate.
        """
        params = self.params
        done = False
        while not done:
            # Clear last cycle's losses.
            for loss in self.losses:
                self.loss_epoch_dict[loss] = []

            self.model_dict["D"].train()
            self.model_dict["G"].train()
            self.set_grad_req(d=True, g=True)

            epoch_start_time = time.time()
            num_iter = 0

            print(f"Sched Cycle:{self.current_cycle}, Sched Iter:{self.current_iter}, Sched Epoch:{self.current_epoch}")
            for opt in self.opt_dict:
                print(f"Learning Rate({opt}): {self.opt_dict[opt].param_groups[0]['lr']}")

            for real_a, real_b in tqdm(self.train_loader):

                # Schedule exhausted: stop; this cycle's preview/validation
                # below still run.
                if self.current_iter > len(self.iter_stack) - 1:
                    done = True
                    self.display_history()
                    print('Hit End of Learning Schedule!')
                    break

                # Set learning rate from the schedule.
                lr_mult, save = self.lr_lookup()
                self.opt_dict["D"].param_groups[0]['lr'] = lr_mult * params['lr_disc']
                self.opt_dict["G"].param_groups[0]['lr'] = lr_mult * params['lr_gen']

                real_a, real_b = Variable(real_a.cuda()), Variable(real_b.cuda())

                # GENERATE
                fake_b = self.model_dict["G"](real_a)
                # TRAIN DISCRIMINATOR
                self.train_disc(real_a, real_b, fake_b)
                # TRAIN GENERATOR
                self.train_gen(real_a, real_b, fake_b)

                # Record batch losses as Python floats. `.data[0]` raises
                # IndexError on 0-dim loss tensors (PyTorch >= 0.4); `.item()`
                # works for any single-element tensor.
                for loss in self.losses:
                    self.loss_epoch_dict[loss].append(self.loss_batch_dict[loss].item())

                if save:
                    save_str = self.save_state(f'output/{params["save_root"]}_{self.current_epoch}.json')
                    tqdm.write(save_str)
                    self.current_epoch += 1

                self.current_iter += 1
                num_iter += 1

            if num_iter == 0 and not done:
                raise RuntimeError(
                    'train_loader yielded no batches; the learning schedule '
                    'can never advance')

            # Generate test images and save to disk.
            helper.show_test(params,
                             self.transform,
                             self.test_loader,
                             self.model_dict['G'],
                             save=f'output/{params["save_root"]}_{self.current_cycle}.jpg')

            # Run validation set loop to get losses.
            self.test_loop()

            if not done:
                self.current_cycle += 1
                per_epoch_ptime = time.time() - epoch_start_time
                print(f'Epoch Training Training Time: {per_epoch_ptime}')
                for loss in self.losses:
                    print(f'Train {loss}: {helper.mft(self.loss_epoch_dict[loss])}')
                for loss in self.losses:
                    print(f'Val {loss}: {helper.mft(self.loss_epoch_dict_test[loss])}')
                print('\n')
                for loss in self.losses:
                    self.train_hist_dict[loss].append(helper.mft(self.loss_epoch_dict[loss]))
Exemplo n.º 5
0
    def train(self):
        """Train generator ``G`` until ``self.current_epoch`` hits ``params["train_epoch"]``.

        Each epoch proceeds as follows:

        1. Puts ``G`` in train mode and calls
           ``self.train_gen(real_default, real_vgg)`` on every
           ``(real_vgg, real_default)`` batch of ``self.train_loader``,
           recording the per-batch losses.
        2. Every ``params["save_img_every"]`` epochs, writes a validation
           preview.
        3. Runs ``self.test_loop()``.
        4. Every ``params["save_every"]`` epochs, calls ``self.save_state``.
        5. Prints timing and the train/val summaries from ``helper.mft``, and
           appends the train summaries to ``self.train_hist_dict``.
        """
        params = self.params
        while self.current_epoch < params["train_epoch"]:
            # Clear last epoch's losses.
            for loss in self.losses:
                self.loss_epoch_dict[loss] = []

            self.model_dict["G"].train()
            epoch_start_time = time.time()

            print(
                f"Sched Sched Iter:{self.current_iter}, Sched Epoch:{self.current_epoch}"
            )
            for opt in self.opt_dict:
                print(
                    f"Learning Rate({opt}): {self.opt_dict[opt].param_groups[0]['lr']}"
                )

            for real_vgg, real_default in tqdm(self.train_loader):
                real_vgg = Variable(real_vgg.cuda())
                real_default = Variable(real_default.cuda())

                # TRAIN GENERATOR (its returned style/content losses were unused).
                self.train_gen(real_default, real_vgg)

                # Append all losses in loss dict.
                for loss in self.losses:
                    self.loss_epoch_dict[loss].append(
                        self.loss_batch_dict[loss].item())
                self.current_iter += 1

            # Generate test images and save to disk.
            if self.current_epoch % params["save_img_every"] == 0:
                helper.show_test(
                    params,
                    self.transform,
                    self.tensor_transform,
                    self.test_loader,
                    self.style,
                    self.model_dict['G'],
                    save=f'output/{params["save_root"]}_val_{self.current_epoch}.jpg',
                )

            # Run validation set loop to get losses.
            self.test_loop()
            if self.current_epoch % params["save_every"] == 0:
                save_str = self.save_state(
                    f'output/{params["save_root"]}_{self.current_epoch}.json')
                tqdm.write(save_str)

            self.current_epoch += 1
            per_epoch_ptime = time.time() - epoch_start_time
            print(f'Epoch Training Training Time: {per_epoch_ptime}')
            for loss in self.losses:
                print(f'Train {loss}: {helper.mft(self.loss_epoch_dict[loss])}')
            for loss in self.losses:
                print(f'Val {loss}: {helper.mft(self.loss_epoch_dict_test[loss])}')
            print('\n')
            for loss in self.losses:
                self.train_hist_dict[loss].append(
                    helper.mft(self.loss_epoch_dict[loss]))

        self.display_history()
        print('Hit End of Learning Schedule!')
Exemplo n.º 6
0
    def train(self):
        """Train ``G`` and ``D`` for ``params["train_epoch"]`` epochs via a batch feeder.

        Each epoch consumes ``self.data_len`` discriminator batches pulled from
        ``helper.BatchFeeder(self.train_loader)``. The batches are consumed in
        runs of up to ``disc_loop_total`` ``self.train_disc`` calls. Each run is
        followed by one ``self.train_gen()`` call, after which the current
        batch losses are recorded.

        After ``self.current_epoch`` is advanced, every ``params['save_every']``
        epochs a preview from ``self.preview_noise`` is written and
        ``self.save_state`` is called with a ``.json`` path.
        """
        params = self.params
        for _ in range(params["train_epoch"]):

            # Clear last epoch's losses.
            for loss in self.losses:
                self.loss_epoch_dict[loss] = []

            print(
                f"Sched Iter:{self.current_iter}, Sched Epoch:{self.current_epoch}"
            )
            for opt in self.opt_dict:
                print(
                    f"Learning Rate({opt}): {self.opt_dict[opt].param_groups[0]['lr']}"
                )

            self.model_dict["G"].train()
            self.model_dict["D"].train()

            batch_feeder = helper.BatchFeeder(self.train_loader)

            epoch_iter_count = 0
            epoch_start_time = time.time()

            # Progress bar spans the epoch's self.data_len discriminator batches.
            with tqdm(total=self.data_len) as epoch_bar:

                while epoch_iter_count < self.data_len:

                    # Discriminator steps per generator step: 100 while
                    # current_iter < 25 and on every 500th generator step,
                    # otherwise 5 (current_iter advances once per train_gen()).
                    disc_loop_total = 100 if (self.current_iter < 25 or
                                              self.current_iter % 500 == 0) else 5
                    self.set_grad_req(d=True, g=False)

                    # TRAIN DISC # (stops early if the epoch's batches run out)
                    disc_loop_count = 0
                    while (disc_loop_count < disc_loop_total
                           and epoch_iter_count < self.data_len):
                        data_iter = batch_feeder.get_new()
                        self.train_disc(data_iter)

                        disc_loop_count += 1
                        epoch_iter_count += 1
                        epoch_bar.update()

                    # TRAIN GEN #
                    self.set_grad_req(d=False, g=True)
                    self.train_gen()

                    # Record losses as Python floats. `.data[0]` raises
                    # IndexError on 0-dim loss tensors (PyTorch >= 0.4);
                    # `.item()` works for any single-element tensor.
                    for loss in self.losses:
                        self.loss_epoch_dict[loss].append(
                            self.loss_batch_dict[loss].item())
                    self.current_iter += 1

            self.current_epoch += 1

            if self.current_epoch % params['save_every'] == 0:
                helper.show_test(
                    self.model_dict['G'],
                    Variable(self.preview_noise),
                    self.transform,
                    save=f'output/{params["save_root"]}_{self.current_epoch}.jpg')

                save_str = self.save_state(
                    f'output/{params["save_root"]}_{self.current_epoch}.json')
                print(save_str)

            per_epoch_ptime = time.time() - epoch_start_time
            for loss in self.losses:
                self.train_hist_dict[loss].append(
                    helper.mft(self.loss_epoch_dict[loss]))
            print(f'Epoch:{self.current_epoch}, Epoch Time:{per_epoch_ptime}')
            for loss in self.losses:
                print(f'Train {loss}: {helper.mft(self.loss_epoch_dict[loss])}')

        self.display_history()
        print('Hit End of Learning Schedule!')
    def train(self):
        """Train two generator/discriminator pairs following the LR schedule.

        Generators: ``G_A`` maps A -> B (``fake_b = G_A(real_a)``) and ``G_B``
        maps B -> A (``fake_a = G_B(real_b)``). Each is applied to the other's
        output to form reconstructions ``rec_a`` / ``rec_b``. Their L1 losses
        are weighted by ``params['cycle_loss_A']`` / ``params['cycle_loss_B']``.

        Discriminators: ``D_A`` scores B-side images and ``D_B`` scores A-side
        images. Both return ``(result, features)``.

        Schedule: makes repeated passes ("cycles") over ``self.train_loader``
        until ``self.current_iter`` passes the end of ``self.iter_stack``. Per
        batch, ``self.lr_lookup()`` returns ``(lr_mult, save)``:

        - ``lr_mult`` scales ``params['lr_disc']`` (D_A, D_B) and
          ``params['lr_gen']`` (G).
        - A truthy ``save`` calls ``self.save_state`` and advances
          ``self.current_epoch``.

        After each cycle ``helper.show_test`` writes a preview named by
        ``self.current_cycle``.
        """
        params = self.params
        done = False
        while not done:
            # clear last epochs losses
            for loss in self.losses:
                self.loss_epoch_dict[loss] = []

            epoch_start_time = time.time()
            num_iter = 0

            print(
                f"Sched Cycle:{self.current_cycle}, Sched Iter:{self.current_iter}, Sched Epoch:{self.current_epoch}"
            )
            [
                print(
                    f"Learning Rate({opt}): {self.opt_dict[opt].param_groups[0]['lr']}"
                ) for opt in self.opt_dict.keys()
            ]

            for (real_a, real_b) in tqdm(self.train_loader):

                # Schedule exhausted: stop after this cycle's preview below.
                if self.current_iter > len(self.iter_stack) - 1:
                    done = True
                    self.display_history()
                    print('Hit End of Learning Schedule!')
                    break

                # Apply the scheduled LR multiplier to all three optimizers.
                lr_mult, save = self.lr_lookup()
                self.opt_dict["D_A"].param_groups[0][
                    'lr'] = lr_mult * params['lr_disc']
                self.opt_dict["D_B"].param_groups[0][
                    'lr'] = lr_mult * params['lr_disc']
                self.opt_dict["G"].param_groups[0][
                    'lr'] = lr_mult * params['lr_gen']

                real_a, real_b = Variable(real_a.cuda()), Variable(
                    real_b.cuda())

                # train generator
                # (opt "G" is the only optimizer stepped for g_loss; presumably
                # it holds both G_A and G_B parameters — confirm)
                self.opt_dict["G"].zero_grad()

                # generate fake b and discriminate
                fake_b = self.model_dict["G_A"](real_a)
                d_a_result, d_a_fake_feats = self.model_dict["D_A"](fake_b)
                self.loss_batch_dict['G_A_loss'] = self.BCE_loss(
                    d_a_result, Variable(torch.ones(d_a_result.size()).cuda()))

                # reconstruct a
                rec_a = self.model_dict["G_B"](fake_b)
                self.loss_batch_dict['Cycle_A_loss'] = self.L1_loss(
                    rec_a, real_a) * params['cycle_loss_A']

                # generate fake a and discriminate
                fake_a = self.model_dict["G_B"](real_b)
                d_b_result, d_b_fake_feats = self.model_dict["D_B"](fake_a)
                self.loss_batch_dict['G_B_loss'] = self.BCE_loss(
                    d_b_result, Variable(torch.ones(d_b_result.size()).cuda()))
                # reconstruct b
                rec_b = self.model_dict["G_A"](fake_a)
                self.loss_batch_dict['Cycle_B_loss'] = self.L1_loss(
                    rec_b, real_b) * params['cycle_loss_B']

                # NOTE(review): D_A/D_B grads are zeroed here, BEFORE
                # g_loss.backward() below. g_loss depends on D_A/D_B outputs
                # (G_*_loss and the feature losses), so that backward also
                # fills D grads, which are not cleared before the D_*_loss
                # backward/step calls. The discriminator updates therefore
                # include generator-loss gradients — verify this is intended.
                self.opt_dict["D_A"].zero_grad()
                self.opt_dict["D_B"].zero_grad()

                # discriminate real samples
                d_a_real, d_a_real_feats = self.model_dict["D_A"](real_b)
                d_b_real, d_b_real_feats = self.model_dict["D_B"](real_a)

                # mean_std_loss is defined elsewhere; it returns a
                # (mean_loss, std_loss) pair comparing real vs fake features.
                d_a_mean_loss, d_a_std_loss = mean_std_loss(
                    d_a_real_feats, d_a_fake_feats)
                d_b_mean_loss, d_b_std_loss = mean_std_loss(
                    d_b_real_feats, d_b_fake_feats)

                # calculate feature loss
                self.loss_batch_dict[
                    'D_A_feat_loss'] = d_a_std_loss + d_a_mean_loss
                self.loss_batch_dict[
                    'D_B_feat_loss'] = d_b_std_loss + d_b_mean_loss

                # add up generator a and b loss and step
                g_a_loss_total = self.loss_batch_dict[
                    'G_A_loss'] * .5 + self.loss_batch_dict[
                        'D_B_feat_loss'] * .5
                g_b_loss_total = self.loss_batch_dict[
                    'G_B_loss'] * .5 + self.loss_batch_dict[
                        'D_A_feat_loss'] * .5

                g_loss = g_a_loss_total + g_b_loss_total + self.loss_batch_dict[
                    'Cycle_A_loss'] + self.loss_batch_dict['Cycle_B_loss']
                # retain_graph: the D_A(real_b) / D_B(real_a) graphs are part of
                # g_loss (via the feature losses) and are backpropagated again
                # through d_a_real / d_b_real in the discriminator losses below.
                g_loss.backward(retain_graph=True)
                self.opt_dict["G"].step()

                # train discriminator a
                d_a_real_loss = self.BCE_loss(
                    d_a_real, Variable(torch.ones(d_a_real.size()).cuda()))

                # Fake B images pass through self.fakeB_pool before D_A scores
                # them (pool behavior is defined elsewhere).
                fake_b = self.fakeB_pool.query(fake_b)
                d_a_fake, d_a_fake_feats = self.model_dict["D_A"](fake_b)
                d_a_fake_loss = self.BCE_loss(
                    d_a_fake, Variable(torch.zeros(d_a_fake.size()).cuda()))

                # add up disc a loss and step
                self.loss_batch_dict['D_A_loss'] = (d_a_real_loss +
                                                    d_a_fake_loss) * .5
                self.loss_batch_dict['D_A_loss'].backward()
                self.opt_dict["D_A"].step()

                # train discriminator b
                d_b_real_loss = self.BCE_loss(
                    d_b_real, Variable(torch.ones(d_b_real.size()).cuda()))

                fake_a = self.fakeA_pool.query(fake_a)
                d_b_fake, d_b_fake_feats = self.model_dict["D_B"](fake_a)
                d_b_fake_loss = self.BCE_loss(
                    d_b_fake, Variable(torch.zeros(d_b_fake.size()).cuda()))

                # add up disc b loss and step
                self.loss_batch_dict['D_B_loss'] = (d_b_real_loss +
                                                    d_b_fake_loss) * .5
                self.loss_batch_dict['D_B_loss'].backward()
                self.opt_dict["D_B"].step()

                # append all losses in loss dict (per-iteration history and
                # per-cycle epoch lists)
                # NOTE(review): `.data[0]` raises IndexError on 0-dim loss
                # tensors in PyTorch >= 0.4; `.item()` is the supported accessor.
                [
                    self.train_hist_dict[loss].append(
                        self.loss_batch_dict[loss].data[0])
                    for loss in self.losses
                ]
                [
                    self.loss_epoch_dict[loss].append(
                        self.loss_batch_dict[loss].data[0])
                    for loss in self.losses
                ]

                if save:
                    save_str = self.save_state(
                        f'output/{params["save_root"]}_{self.current_epoch}.json'
                    )
                    tqdm.write(save_str)
                    self.current_epoch += 1

                self.current_iter += 1
                num_iter += 1

            helper.show_test(
                self.model_dict['G_A'],
                self.model_dict['G_B'],
                params,
                save=f'output/{params["save_root"]}_{self.current_cycle}.jpg')
            # Per-cycle bookkeeping is skipped once the schedule has finished.
            if not done:
                self.current_cycle += 1
                epoch_end_time = time.time()
                per_epoch_ptime = epoch_end_time - epoch_start_time
                self.train_hist_dict['per_epoch_ptimes'].append(
                    per_epoch_ptime)
                print(
                    f'Epoch:{self.current_epoch}, Epoch Time:{per_epoch_ptime}'
                )
                [
                    print(f'{loss}: {helper.mft(self.loss_epoch_dict[loss])}')
                    for loss in self.losses
                ]