Ejemplo n.º 1
0
    def fit(self, train_x, train_y, val_x, val_y, bin_size, lr, batch_size,
            with_gap, earlystop, verbose):
        """Train ``self.model`` with Adam on batches drawn from ``self.loader``.

        At most ``self.maxepochs`` optimisation steps are run. After every
        step one value may be recorded for the convergence check:

        * with ``earlystop`` set, the result of
          ``self.evaluate(val_x, val_y)``;
        * otherwise the training loss, but only when one batch covers the
          whole training set (``train_x.size()[0] == batch_size``).

        Training stops early once more than five values are recorded and the
        last two differ by less than 0.01 % relative to the latest one.

        ``train_y``, ``bin_size``, ``with_gap`` and ``verbose`` are accepted
        but not used by this method; batches come from
        ``self.loader.get_batch()``.

        Returns:
            ``(None, last_recorded_validation_loss)`` when ``earlystop`` is
            set, otherwise ``(last_training_loss, self.evaluate(val_x,
            val_y))``. The training loss is a 0-d NumPy array. A loss slot is
            ``None`` when no step was run.
        """
        self.batch_size = batch_size
        optimizer = torch.optim.Adam(self.model.parameters(),
                                     lr=lr,
                                     amsgrad=True)
        criterion = MSELoss(reduction='mean')
        loss_list = []
        last_loss = None
        for _ in range(self.maxepochs):
            optimizer.zero_grad()
            x, y = self.loader.get_batch()
            x = x.to(self.device)
            y = y.to(self.device)

            ypred = self.model(x)

            # Drop the trailing singleton dimension so that prediction and
            # target have the same shape and MSELoss does not broadcast.
            if ypred.dim() == 2:
                ypred = ypred.squeeze(1)
            assert ypred.size() == y.size()
            loss = criterion(ypred, y)
            loss.backward()
            optimizer.step()

            # Keep only a graph-free handle; conversion to NumPy is deferred
            # so that no device synchronisation is forced on every step.
            last_loss = loss.detach()
            if earlystop:
                # list.append returns None, so its result must not be
                # assigned back to the list.
                loss_list.append(self.evaluate(val_x, val_y))
            elif train_x.size()[0] == batch_size:
                loss_list.append(last_loss.cpu().numpy())

            # Relative-change test written without a division so that a loss
            # of exactly zero cannot raise or produce inf/nan.
            if len(loss_list) > 5 and \
                    abs(loss_list[-2] - loss_list[-1]) \
                    < 0.0001 * abs(loss_list[-1]):
                break

        if earlystop:
            return None, (loss_list[-1] if loss_list else None)
        train_loss = None if last_loss is None else last_loss.cpu().numpy()
        return train_loss, self.evaluate(val_x, val_y)
Ejemplo n.º 2
0
    def compute_q_loss(global_state, state, reward, next_global_state,
                       next_state):
        """Compute the one-step TD loss for ``training_agents[0]`` and backpropagate it.

        The current Q-values come from ``training_agents[0].act`` and the
        bootstrap values from ``training_agents[0].target_act``. The target is
        ``reward.mean(dim=1) + args.gamma * max_a Q_target(next, a)``.

        Inputs are moved to the GPU when ``args.cuda`` is set. Gradients are
        accumulated by ``loss.backward()``; no optimiser step is taken here.

        Returns:
            The MSE loss as a Python float.
        """
        if args.cuda:
            global_state = global_state.cuda()
            state = state.cuda()
            reward = reward.cuda()
            next_global_state = next_global_state.cuda()
            next_state = next_state.cuda()

        global_state = Variable(global_state, requires_grad=True)
        state = Variable(state, requires_grad=True)

        current_q_values, _ = training_agents[0].act(global_state, state)
        max_next_q_values, _ = training_agents[0].target_act(
            next_global_state, next_state)
        # The `volatile=True` flag that used to keep the target out of the
        # autograd graph is ignored by current PyTorch, so the bootstrap value
        # is detached explicitly: the target must act as a constant.
        max_next_q_values = max_next_q_values.max(1)[0].detach()
        # average the rewards of the individual agents (dim 1)
        expected_q_values = Variable(
            reward.mean(dim=1)) + args.gamma * max_next_q_values

        loss = MSELoss()(current_q_values, expected_q_values)
        loss.backward()

        # The loss is a 0-dim tensor; indexing it with [0] raises on current
        # PyTorch, `.item()` is the supported way to read the scalar.
        return loss.item()
Ejemplo n.º 3
0
def inference_demo(screen):
    """Recover the chain mass by gradient descent through the simulator.

    A reference trajectory is produced with a ground-truth mass of 7. Starting
    from a mass of 1.3, each iteration re-simulates the world, measures the MSE
    between the simulated and reference positions, and takes one clipped
    gradient step on the mass (kept at or above ``MASS_EPS``). Iteration stops
    after ``max_iter`` steps or when the loss changes by less than 1e-3.

    Finally the world is simulated once more with the learned mass (rendered
    on ``screen``), the final loss and mass are printed, and the loss and mass
    histories are plotted.
    """
    forces = [hor_impulse]
    ground_truth_mass = Variable(torch.DoubleTensor([7]))
    world, c = make_world(forces, ground_truth_mass, num_links=NUM_LINKS)

    rec = None
    # rec = Recorder(DT, screen)
    ground_truth_pos = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    # Strip autograd history: the reference trajectory is a constant target.
    ground_truth_pos = [p.data for p in ground_truth_pos]
    ground_truth_pos = Variable(torch.cat(ground_truth_pos))

    learning_rate = 0.01
    max_iter = 100
    grad_limit = 100.0

    next_mass = Variable(torch.DoubleTensor([1.3]), requires_grad=True)
    # Histories hold plain floats so that no autograd graph is kept alive and
    # the values can be plotted directly.
    loss_hist = []
    mass_hist = [next_mass.item()]
    last_loss = 1e10
    for i in range(max_iter):
        world, c = make_world(forces, next_mass, num_links=NUM_LINKS)
        positions = positions_run_world(world, run_time=10, screen=None)
        positions = torch.cat(positions)
        # Clip both trajectories to their common length so the loss is always
        # computed over tensors of equal shape.
        positions = positions[:len(ground_truth_pos)]
        clipped_ground_truth_pos = ground_truth_pos[:len(positions)]

        loss = MSELoss()(positions, clipped_ground_truth_pos)
        loss.backward()
        loss_value = loss.item()

        # clip gradient to [-grad_limit, grad_limit]
        grad = c.mass.grad.data.clamp(-grad_limit, grad_limit)
        updated_mass = (c.mass.data - learning_rate * grad)[0].item()
        updated_mass = max(MASS_EPS, updated_mass)
        # A fresh leaf per iteration, so gradients never accumulate across
        # steps.
        next_mass = Variable(torch.DoubleTensor([updated_mass]), requires_grad=True)

        print(i, '/', max_iter, loss_value)
        print(grad)
        print(next_mass)
        if abs(last_loss - loss_value) < 1e-3:
            break
        last_loss = loss_value
        loss_hist.append(loss_value)
        mass_hist.append(updated_mass)

    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    rec = None
    # rec = Recorder(DT, screen)
    # Evaluate the final loss on the trajectory of this final run, not on the
    # positions left over from the last training iteration.
    positions = positions_run_world(world, run_time=10, screen=screen, recorder=rec)
    positions = torch.cat(positions)
    positions = positions[:len(ground_truth_pos)]
    loss = MSELoss()(positions, ground_truth_pos[:len(positions)])
    print(loss.item())
    print(next_mass)

    plot(loss_hist)
    plot(mass_hist)
    def one_batch_train(self, Xs, Y):
        """Run one optimisation step on a single batch and checkpoint the model.

        Args:
            Xs: mapping of input features; only ``Xs['word_vectors']`` is
                read.
            Y: regression targets compared with the model output via MSE.

        Returns:
            A tuple ``(self.model.data as a NumPy array, loss as a float)``.
        """
        Ws = Xs['word_vectors']

        self.optimizer.zero_grad()
        outputs = self.model(Ws)
        # MSELoss is a module: instantiate it first, then call the instance on
        # the tensors. Passing the tensors to the constructor treats them as
        # configuration flags and never yields a loss tensor.
        mse_loss = MSELoss()(outputs, Y)

        mse_loss.backward()
        self.optimizer.step()
        self.save_checkpoint('../../../data/nn/lstm.pth.tar')
        # NOTE(review): this returns `self.model.data`, not `outputs`; a plain
        # nn.Module has no `data` attribute, so confirm whether the batch
        # predictions were meant here.
        return self.model.data.cpu().numpy(), mse_loss.item()
Ejemplo n.º 5
0
    def update(self, curr_states, next_states, actions, rewards, terminals,
               discount, curr_model, eval_model, optimizer):
        """Perform one Q-learning step on overall and per-reward-type Q-values.

        ``curr_model(states)`` must return a pair
        ``(decomposed_q, q)``; the code indexes ``decomposed_q`` along dim 2
        and ``q`` along dim 1 by action. The loss is the sum of the MSE on the
        overall Q-value and the MSE on the decomposed Q-values. Gradients are
        clipped to a norm of 100 before ``optimizer.step()``.

        Args:
            curr_states, next_states: state batches fed to ``curr_model``.
            actions: one action index per transition.
            rewards: per-transition rewards, one entry per reward type.
            terminals: truthy where the transition ended the episode.
            discount: discount factor.
            curr_model: network being trained.
            eval_model: accepted but not used by this method.
            optimizer: optimiser over ``curr_model``'s parameters.

        NOTE(review): next-state values are predicted with ``curr_model``
        rather than ``eval_model`` — confirm this is intended.
        NOTE(review): the targets are ``discount * (reward + next_value)``,
        i.e. the reward is discounted too; the usual Bellman target is
        ``reward + discount * next_value`` — confirm.
        """
        # Boolean mask of transitions that have a successor state. A bool mask
        # is used because indexing with uint8 tensors is deprecated.
        non_final_mask = ~torch.as_tensor(terminals, dtype=torch.bool)
        curr_states = Variable(curr_states, requires_grad=True)
        next_states = Variable(next_states, requires_grad=False)
        reward_batch = torch.FloatTensor(rewards)

        actions = torch.LongTensor(actions).unsqueeze(1)
        # Same action index repeated once per reward type, shaped for gather
        # along the action dimension of the decomposed Q-values.
        repeated_actions = actions.repeat(1,
                                          len(self.reward_types)).unsqueeze(2)

        # make prediction for decomposed q-values for current_state
        decomp_q_curr_state, q_curr_state = curr_model(curr_states)
        decomp_q_curr_state = decomp_q_curr_state.gather(2, repeated_actions)
        q_curr_state = q_curr_state.gather(1, actions)

        # make prediction for decomposed q-values for next_state
        decomp_q_next_state, q_next_state = curr_model(
            next_states[non_final_mask])
        q_next_state, q_next_state_actions = q_next_state.max(1)
        q_next_state_actions = q_next_state_actions.unsqueeze(1).repeat(
            1, len(self.reward_types)).unsqueeze(2)

        # Calculate Targets. Terminal transitions keep a next-state value of 0.
        target_q = torch.zeros_like(q_curr_state)
        target_q[non_final_mask] = q_next_state.detach().unsqueeze(1)
        target_q = discount * (reward_batch.sum(1, keepdim=True) + target_q)

        # The decomposed target is detached as well, so that, like the overall
        # target, it is treated as a constant and no gradient flows through
        # the next-state prediction.
        target_decomp_q = torch.zeros_like(decomp_q_curr_state)
        target_decomp_q[non_final_mask] = decomp_q_next_state.gather(
            2, q_next_state_actions).detach()
        target_decomp_q = discount * (reward_batch.unsqueeze(2) +
                                      target_decomp_q)

        # calculate loss
        criterion = MSELoss()
        loss = criterion(q_curr_state, target_q)
        loss += criterion(decomp_q_curr_state, target_decomp_q)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(curr_model.parameters(), 100)
        optimizer.step()
Ejemplo n.º 6
0
def train_model():
    """Train ``model`` for one pass over ``train_loader``.

    The per-batch objective is the square root of the MSE between the
    flattened model output and the float targets. When
    ``args.neg_edge_percent`` is not 100, only the leading share of the
    zero-target samples (that percentage of them) is kept in the batch, next
    to all non-zero-target samples.

    Returns:
        The accumulated ``loss * len(data)`` over all batches divided by
        ``len(train_dataset)``.
    """
    model.train()

    running_loss = 0
    progress = tqdm(train_loader)
    for data in progress:
        if not args.multi_gpu:
            data = data.to(device)
        optimizer.zero_grad()

        if args.use_orig_graph:
            # Optional inputs are passed as None when the matching feature is
            # switched off.
            features = data.x if args.use_feature else None
            weights = data.edge_weight if args.use_edge_weight else None
            ids = data.node_id if emb else None
            out = model(data.z, data.edge_index, data.batch, features,
                        weights, ids)
            out = out.view(-1)
        else:
            out = model(data).view(-1)

        if args.multi_gpu:
            # Here `data` is iterated as a collection of items whose targets
            # are gathered onto the device holding the output.
            targets = [item.y.to(torch.float) for item in data]
            y = torch.cat(targets).to(out.device)
        else:
            y = data.y.to(torch.float)

        if args.neg_edge_percent != 100:
            is_negative = y == 0
            is_positive = y != 0
            negative_out = out[is_negative]
            keep = int(args.neg_edge_percent / 100 * len(negative_out))
            # Positives first, then the first `keep` negatives, in both the
            # prediction and the target tensor.
            out = torch.cat([out[is_positive], negative_out[:keep]])
            y = torch.cat([y[is_positive], y[is_negative][:keep]])

        loss = torch.sqrt(MSELoss()(out, y))
        loss.backward()
        optimizer.step()
        # NOTE(review): the weight is len(data); confirm this is the number of
        # samples in the batch for the loader's batch type.
        running_loss += loss.item() * len(data)

    return running_loss / len(train_dataset)
Ejemplo n.º 7
0
def main(screen):
    """Estimate the chain's total mass by differentiating through the simulator.

    A reference trajectory is simulated with ``TOTAL_MASS``. A randomly
    initialised mass is then optimised with RMSprop for up to 100 iterations
    so that the simulated positions match the reference (MSE), stopping early
    when the loss changes by less than ``STOP_DIFF``. A final simulation with
    the learned mass is rendered on ``screen``, and the loss and mass
    histories are plotted.
    """
    forces = [hor_impulse]
    true_mass = torch.tensor([TOTAL_MASS], dtype=DTYPE)
    world, chain = make_world(forces, true_mass, num_links=NUM_LINKS)

    recorder = None
    # recorder = Recorder(DT, screen)
    reference = positions_run_world(world, run_time=10, screen=screen, recorder=recorder)
    # `.data` drops the autograd history; the reference is a fixed target.
    reference = torch.cat([step.data for step in reference])

    def loss_against_reference(trajectory):
        # Both sequences are cut to their common length before comparison.
        simulated = torch.cat(trajectory)
        simulated = simulated[:len(reference)]
        return MSELoss()(simulated, reference[:len(simulated)])

    learning_rate = 0.5
    max_iter = 100

    next_mass = torch.rand_like(true_mass, requires_grad=True)
    print('\rInitial mass:', next_mass.item())
    print('-----')

    optim = torch.optim.RMSprop([next_mass], lr=learning_rate)
    loss_hist = []
    mass_hist = [next_mass.item()]
    last_loss = 1e10
    for iteration in range(1, max_iter + 1):
        # Every iteration first runs the world once with a detached copy of
        # the current mass; the result of that run is not used.
        world, chain = make_world(forces, next_mass.clone().detach(), num_links=NUM_LINKS)
        run_world(world, run_time=10, print_time=False, screen=None, recorder=None)

        # Differentiable run used for the actual parameter update.
        world, chain = make_world(forces, next_mass, num_links=NUM_LINKS)
        trajectory = positions_run_world(world, run_time=10, screen=None)

        optim.zero_grad()
        loss = loss_against_reference(trajectory)
        loss.backward()

        optim.step()

        print('Iteration: {} / {}'.format(iteration, max_iter))
        print('Loss:', loss.item())
        print('Gradient:', next_mass.grad.item())
        print('Next mass:', next_mass.item())
        print('-----')
        loss_change = (last_loss - loss).item()
        if abs(loss_change) < STOP_DIFF:
            print('Loss changed by less than {} between iterations, stopping training.'
                  .format(STOP_DIFF))
            break
        last_loss = loss
        loss_hist.append(loss.item())
        mass_hist.append(next_mass.item())

    # Final rendered run with the learned mass.
    world = make_world(forces, next_mass, num_links=NUM_LINKS)[0]
    recorder = None
    trajectory = positions_run_world(world, run_time=10, screen=screen, recorder=recorder)
    loss = loss_against_reference(trajectory)
    print('Final loss:', loss.item())
    print('Final mass:', next_mass.item())

    plot(loss_hist)
    plot(mass_hist)
Ejemplo n.º 8
0
def main(screen):
    """Optimise the control tensor ``utT`` with RMSprop while plotting progress.

    Each iteration builds a world from ``forces``, ``ctT`` and ``utT``, runs
    it for ``runtime``, and minimises a loss made of the MSE between the
    simulated positions and a constant target (alternating 500 and 240) plus
    the absolute values of the first three entries of the last velocity
    record. The loss curve and the trajectory are drawn live and saved to
    '2step.png'; every 20 iterations the plots are cleared and ``utT`` is
    saved into a directory named after the start timestamp.

    ``utT``, ``ctT``, ``runtime``, ``DTYPE`` and ``Defaults`` are not defined
    in this function — presumably module-level names; verify against the rest
    of the file.
    """
    dateTimeObj = datetime.now()

    # The timestamp doubles as the name of the output directory for the saved
    # control tensors.
    timestampStr = dateTimeObj.strftime("%d-%b-%Y(%H:%M)")
    print('Current Timestamp : ', timestampStr)

    if not os.path.isdir(timestampStr):
        os.mkdir(timestampStr)

    #if torch.cuda.is_available():
    #    dev = "cuda:0"
    #else:
    #    dev = "cpu"

    # No external forces are applied to the world.
    forces = []

    rec = None
    #rec = Recorder(DT, screen)

    #plot(zhist)

    learning_rate = 0.5
    # NOTE: max_iter is not used below; the loop bound is hard-coded.
    max_iter = 100

    # NOTE(review): plt.subplots(21) requests 21 rows of axes; the code below
    # only uses the 2x1 layout (211 / 212) — confirm this call is intended.
    plt.subplots(21)
    plt.subplot(212)
    plt.gca().invert_yaxis()

    optim = torch.optim.RMSprop([utT], lr=learning_rate)

    # NOTE: last_loss is never read again in this function.
    last_loss = 1e10
    lossHist = []

    for i in range(1, 20000):
        world, chain = make_world(forces, ctT, utT)

        # hist is indexed as (positions, velocities, controls).
        hist = positions_run_world(world,
                                   run_time=runtime,
                                   screen=screen,
                                   recorder=rec)
        xy = hist[0]
        vel = hist[1]
        control = hist[2]
        xy = torch.cat(xy).to(device=Defaults.DEVICE)
        # Even / odd entries of the flat position tensor — presumably x and y
        # coordinates; x1 and y1 are not used afterwards.
        x1 = xy[0::2]
        y1 = xy[1::2]
        control = torch.cat(control).to(device=Defaults.DEVICE)
        optim.zero_grad()

        # Build constant targets with one (500, 240) pair and one (0, 0) pair
        # per two entries of xy.
        targetxy = []
        targetControl = []
        j = 0
        while (j < xy.size()[0]):
            targetxy.append(500)
            targetxy.append(240)

            targetControl.append(0)
            targetControl.append(0)

            j = j + 2

        tt = torch.tensor(targetxy,
                          requires_grad=True,
                          dtype=DTYPE,
                          device=Defaults.DEVICE).t()
        tc = torch.tensor(targetControl,
                          requires_grad=True,
                          dtype=DTYPE,
                          device=Defaults.DEVICE).t()
        #loss = MSELoss()(zhist, 150*torch.tensor(np.ones(zhist.size()),requires_grad=True, dtype=DTYPE,device=Defaults.DEVICE))/100
        #loss = MSELoss()(xy, tt) / 10 + 0*MSELoss()(control, tc) / 1000 + abs(vel[-1][0]) + abs(vel[-1][1]) + abs(vel[-1][2])

        # Position error plus the magnitude of the first three entries of the
        # final velocity record. The control term is multiplied by 0, so it
        # adds nothing to the loss value.
        loss = MSELoss()(xy, tt) / 10 + 0 * MSELoss()(
            control, tc) / 1000 + abs(vel[-1][0]) + abs(vel[-1][1]) + abs(
                vel[-1][2])
        #loss = zhist[-1]
        loss.backward()

        lossHist.append(loss.item())
        optim.step()

        print('Loss:', loss.item())
        print('Gradient:', utT.grad)
        print('Next u:', utT)

        #plt.axis([-1, 1, -1, 1])
        # Interactive mode so the figure updates without blocking the loop.
        plt.ion()
        plt.show()

        # Top panel: loss values recorded since the last reset.
        plt.subplot(211)

        plt.plot(lossHist)
        plt.draw()

        # Bottom panel: every fourth entry of the flat position tensor,
        # starting at index 0 (horizontal axis) and index 1 (vertical axis).
        pl1 = xy.cpu()
        plt.subplot(212)

        plt.plot(pl1.clone().detach().numpy()[::4],
                 pl1.clone().detach().numpy()[1::4])
        #plt.plot(zhist)
        plt.draw()

        plt.pause(0.001)

        # The same file is overwritten on every iteration.
        plt.savefig('2step.png')

        # Every 20 iterations: clear both panels, restart the loss history and
        # checkpoint the current control tensor.
        if (i % 20) == 0:
            plt.subplot(211)
            plt.cla()
            plt.subplot(212)
            plt.cla()
            plt.gca().invert_yaxis()
            lossHist.clear()
            torch.save(utT, timestampStr + '/file' + str(i) + '.pt')

        # Snapshot of the latest controls, used for the final run below.
        utTCurrent = utT.clone()

    world, chain = make_world(forces, ctT, utTCurrent)

    # Final run with the last control snapshot, without rendering.
    positions_run_world(world, run_time=runtime, screen=None, recorder=rec)
    def train(self, epochs, load=False):
        """Train the adversarial autoencoder together with the (de)embedder.

        For every batch the discriminator is first trained to output 0 for
        latent vectors drawn from a standard normal and 1 for encoded data;
        then encoder and decoder are trained on the reconstruction MSE plus
        an adversarial term that pushes the discriminator output for encoded
        data towards 0. The de-embedder is trained after each batch. Every
        100 batches diagnostics are printed and saved; after each epoch the
        models and the error histories are written to disk.

        Args:
            epochs: number of passes over the training data.
            load: when true, model weights are restored from the save paths
                defined on ``Trainer`` before training.

        Returns:
            ``(encoder, decoder, discriminator, embedder, deembedder)``.
        """
        print(f'TRAINING MODEL: BATCH_SIZE = {Trainer.BATCH_SIZE}'
              f', PARTICLE_DIM: {PARTICLE_DIM}'
              f', EPOCHS: {epochs}'
              f', PRTCL_LATENT_SPACE_SIZE: {self.PRTCL_LATENT_SPACE_SIZE}')

        encoder, decoder, discriminator = self.create_model()
        encoder_opt = torch.optim.Adam(encoder.parameters(), lr=self.LR)
        decoder_opt = torch.optim.Adam(decoder.parameters(), lr=self.LR)
        discriminator_opt = torch.optim.Adam(discriminator.parameters(),
                                             lr=self.LR)

        embedder = self.create_embedder()
        deembedder = self.create_deembedder()

        if load:
            print('LOADING MODEL STATES...')
            checkpoints = (
                (encoder, Trainer.ENCODER_SAVE_PATH),
                (decoder, Trainer.DECODER_SAVE_PATH),
                (discriminator, Trainer.DISCRIMINATOR_SAVE_PATH),
                (embedder, Trainer.PDG_EMBED_SAVE_PATH),
                (deembedder, Trainer.PDG_DEEMBED_SAVE_PATH),
            )
            for module, path in checkpoints:
                module.load_state_dict(torch.load(path))

        print('AUTOENCODER')
        for module in (encoder, decoder, discriminator):
            print(module)
        print('EMBEDDER')
        print(embedder)
        print('DEEMBEDDER')
        print(deembedder)

        _data = load_data()
        data_train, data_valid = self.prep_data(_data,
                                                batch_size=Trainer.BATCH_SIZE,
                                                valid=0.1)

        particles = torch.tensor(particle_idxs(), device=self.device)
        particles.requires_grad = False

        mse = MSELoss()
        hist_attr_idxs = [FEATURES - offset for offset in (8, 7, 6, 5)]

        for epoch in range(epochs):

            for n_batch, batch in enumerate(data_train):
                for module in (encoder, decoder, discriminator):
                    module.zero_grad()

                real_data = batch.to(self.device)
                # Embedded inputs are detached: the embedder is not trained
                # through the autoencoder losses.
                emb_data = self.embed_data(real_data, [embedder]).detach()

                batch_size = len(batch)

                label_zero = torch.zeros(batch_size,
                                         device=self.device,
                                         requires_grad=False)
                label_one = torch.ones(batch_size,
                                       device=self.device,
                                       requires_grad=False)

                # ======== Train Discriminator ======== #
                decoder.freeze(True)
                encoder.freeze(True)
                discriminator.freeze(False)

                prior_latent = torch.randn(batch_size,
                                           self.PRTCL_LATENT_SPACE_SIZE,
                                           device=self.device)
                score_prior = discriminator(prior_latent)

                encoded = encoder(emb_data)
                score_encoded = discriminator(encoded)

                # Prior samples are labelled 0, encoded data 1.
                loss_prior = mse(score_prior, label_zero)
                loss_encoded = mse(score_encoded, label_one)

                loss_prior.backward()
                loss_encoded.backward()

                discriminator_opt.step()

                # ======== Train Generator ======== #
                decoder.freeze(False)
                encoder.freeze(False)
                discriminator.freeze(True)

                encoded = encoder(emb_data)
                recon_data = decoder(encoded)
                # The encoder is run a second time for the adversarial term.
                score_encoded = discriminator(encoder(emb_data))

                recon_loss = mse(emb_data, recon_data)
                adversarial_loss = mse(score_encoded, label_zero)

                recon_loss.backward()
                adversarial_loss.backward()

                encoder_opt.step()
                decoder_opt.step()

                self.train_deembeders([(particles, embedder, deembedder)],
                                      epochs=2)

                if n_batch % 100 != 0:
                    continue

                # ---- periodic diagnostics, every 100th batch ---- #
                self.print_deemb_quality(particles, embedder, deembedder)

                self.show_heatmaps(emb_data[:30, :],
                                   recon_data[:30, :],
                                   reprod=False,
                                   save=True,
                                   epoch=epoch,
                                   batch=n_batch)
                err_kld, err_wass = self.gen_show_comp_hists(
                    decoder,
                    _data,
                    attr_idxs=hist_attr_idxs,
                    embedders=[embedder],
                    emb=False,
                    deembedder=deembedder,
                    save=True,
                    epoch=epoch,
                    batch=n_batch)

                self.errs_kld.append(err_kld)
                self.errs_wass.append(err_wass)

                valid_loss = self._valid_loss(encoder, decoder, embedder,
                                              data_valid)

                report = [
                    'Epoch: ' + str(epoch) + '/' + str(epochs),
                    'Batch: ' + str(n_batch) + '/' + str(len(data_train)),
                    'train loss: ' +
                    '{:.6f}'.format(round(recon_loss.item(), 6)),
                    'valid loss: ' + '{:.6f}'.format(round(valid_loss, 6)),
                    'err kld: ' + '{:.6f}'.format(round(err_kld, 6)),
                    'err wass: ' + '{:.6f}'.format(round(err_wass, 6)),
                ]
                print(' :: '.join(report))

            # ---- end of epoch: persist models and error histories ---- #
            self._save_models(encoder, decoder, discriminator, embedder,
                              deembedder)

            with open(self.ERRS_SAVE_PATH, 'wb') as handle:
                pickle.dump((self.errs_kld, self.errs_wass),
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        return encoder, decoder, discriminator, embedder, deembedder
Ejemplo n.º 10
0
def train(*args, **kwargs):
    """Train an ``LstmModelDynamics`` network to predict the next acoustic state.

    Expected keyword arguments (all read from ``kwargs``):
        train: dict with 'device', 'learning_rate', 'learning_rate_eps',
            'minibatch_size' and 'num_steps'.
        preprocessing_params: keyword arguments for ``AudioPreprocessor``;
            must contain 'sample_rate'.
        preproc_net_fname: path of the saved preprocessing network.
        model_dynamics_params: keyword arguments for ``LstmModelDynamics``;
            must contain 'acoustic_state_dim'.
        data_fname: path of a pickled DataFrame with 'states', 'actions' and
            'audio' columns.

    Each step samples a minibatch, encodes the preprocessed audio with the
    preprocessing network, and minimises the summed squared error between the
    predicted states (all but the last time step) and the encoded states
    shifted by one step, normalised by ``seq_len * minibatch_size``.

    Positional arguments are ignored. Nothing is returned; progress is
    printed on a single, continuously overwritten line.
    """
    print(kwargs)

    train_params = kwargs['train']
    device = train_params['device']
    minibatch_size = train_params['minibatch_size']

    # 1. Init audio preprocessing
    preproc = AudioPreprocessor(**kwargs['preprocessing_params'])
    sr = kwargs['preprocessing_params']['sample_rate']

    # 2. Load preprocessing net
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    preproc_net = torch.load(kwargs['preproc_net_fname']).to(device)

    # 3. Init model dynamics net
    md_net = LstmModelDynamics(**kwargs['model_dynamics_params']).to(device)
    optim = torch.optim.Adam(md_net.parameters(),
                             lr=train_params['learning_rate'],
                             eps=train_params['learning_rate_eps'])
    acoustic_state_dim = kwargs['model_dynamics_params']["acoustic_state_dim"]

    # 4. Load training set
    df = pd.read_pickle(kwargs['data_fname'])

    # 5. Train loop
    criterion = MSELoss(reduction='sum')
    md_net.train()
    for i in range(train_params['num_steps']):
        sample = df.sample(n=minibatch_size)
        states = np.stack(sample.loc[:, 'states'].values)
        actions = np.stack(sample.loc[:, 'actions'].values)
        audio = np.stack(sample.loc[:, 'audio'].values)

        preproc_audio = np.array([preproc(clip, sr) for clip in audio])
        audio_features = torch.from_numpy(preproc_audio).float().to(device)

        # The preprocessing net is not part of the optimiser, so its forward
        # pass needs no autograd graph. Running it under no_grad leaves the
        # gradients of md_net unchanged and stops gradients from piling up in
        # preproc_net's parameters, which are never zeroed.
        with torch.no_grad():
            _, _, acoustic_states = preproc_net(
                audio_features,
                seq_lens=np.array([preproc_audio.shape[-2]]))

        seq_len = actions.shape[1]
        normaliser = seq_len * minibatch_size

        # forward prop
        _, predicted_acoustic_states = md_net(
            acoustic_states,
            torch.from_numpy(states[:, :seq_len, :]).float().to(device),
            torch.from_numpy(actions).float().to(device))

        # One-step-ahead targets: the prediction at step t is compared with
        # the encoded state at step t + 1.
        next_states = acoustic_states[:, 1:, :].contiguous().view(-1, acoustic_state_dim)

        # compute error
        loss = criterion(
            predicted_acoustic_states[:, :-1, :].contiguous().view(-1, acoustic_state_dim),
            next_states) / normaliser

        # backprop
        optim.zero_grad()
        loss.backward()
        optim.step()

        # Baseline for comparison: the error of simply repeating the current
        # state as the prediction of the next one.
        dynamics = criterion(
            acoustic_states[:, :-1, :].contiguous().view(-1, acoustic_state_dim),
            next_states) / normaliser

        print("\rstep: {} | loss: {:.4f}| actual_dynamics: {:.4f}".format(i, loss.item(), dynamics.item()), end="")
Ejemplo n.º 11
0
    def train(self, epochs, load=False):
        """Train the least-squares GAN together with the (de)embedder.

        For every batch the generator is updated first, so that the
        discriminator scores generated samples as valid (1). The discriminator
        is then updated on the mean of its MSE on embedded real data
        (target 1) and on detached generated data (target 0). The de-embedder
        is trained after each batch. Every 100 batches diagnostics are printed
        and saved; after each epoch the models and the error histories are
        written to disk.

        Args:
            epochs: number of passes over the training data.
            load: when true, model weights are restored from the save paths
                defined on ``Trainer`` before training.

        Returns:
            ``(generator, discriminator, embedder, deembedder)``.
        """
        print('TRAINING MODEL:'
              ' BATCH_SIZE = ' + str(self.BATCH_SIZE) + ', PARTICLE_DIM: ' +
              str(PARTICLE_DIM) + ', EPOCHS: ' + str(epochs) +
              ', PRTCL_LATENT_SPACE_SIZE: ' +
              str(self.PRTCL_LATENT_SPACE_SIZE))

        generator, discriminator = self.create_model()
        gen_optim = torch.optim.Adam(generator.parameters(),
                                     lr=self.LR,
                                     betas=(0, .9))
        dis_optim = torch.optim.Adam(discriminator.parameters(),
                                     lr=self.LR,
                                     betas=(0, .9))

        embedder = self.create_embedder()
        deembedder = self.create_deembedder()

        particles = torch.tensor(particle_idxs(), device=self.device)

        if load:
            print('LOADING MODEL STATES...')
            generator.load_state_dict(torch.load(Trainer.GENERATOR_SAVE_PATH))
            discriminator.load_state_dict(
                torch.load(Trainer.DISCRIMINATOR_SAVE_PATH))
            embedder.load_state_dict(torch.load(Trainer.PDG_EMBED_SAVE_PATH))
            deembedder.load_state_dict(
                torch.load(Trainer.PDG_DEEMBED_SAVE_PATH))

        print('GENERATOR')
        print(generator)
        print('DISCRIMINATOR')
        print(discriminator)
        print('EMBEDDER')
        print(embedder)
        print('DEEMBEDDER')
        print(deembedder)

        _data = load_data()
        # The validation split is produced by prep_data but not used here.
        data_train, _ = self.prep_data(_data,
                                       batch_size=self.BATCH_SIZE,
                                       valid=0.1)

        criterion = MSELoss()

        for epoch in range(epochs):

            for n_batch, batch in enumerate(data_train):

                real_data = batch.to(self.device)
                # Embedded inputs are detached: the embedder is not trained
                # through the GAN losses.
                emb_data = self.embed_data(real_data, [embedder]).detach()

                batch_size = len(batch)

                valid = torch.ones(batch_size,
                                   device=self.device,
                                   requires_grad=False)
                fake = torch.zeros(batch_size,
                                   device=self.device,
                                   requires_grad=False)

                # ======== Train Generator ======== #
                gen_optim.zero_grad()

                # Sample standard-normal noise as generator input. It is drawn
                # once, directly on the target device; the previous code drew
                # it twice and threw the first sample away.
                lat_fake = torch.randn(batch_size,
                                       self.PRTCL_LATENT_SPACE_SIZE,
                                       device=self.device)
                # Generate a batch of samples
                gen_data = generator(lat_fake)

                # Loss measures generator's ability to fool the discriminator
                g_loss = criterion(discriminator(gen_data), valid)

                g_loss.backward()
                gen_optim.step()

                # ======== Train Discriminator ======== #
                # Also clears the discriminator gradients produced by the
                # generator's backward pass above.
                dis_optim.zero_grad()

                real_loss = criterion(discriminator(emb_data), valid)
                # Generated data is detached so that this step does not
                # back-propagate into the generator.
                fake_loss = criterion(discriminator(gen_data.detach()), fake)
                d_loss = (real_loss + fake_loss) / 2

                d_loss.backward()
                dis_optim.step()

                self.train_deembeders([
                    (particles, embedder, deembedder),
                ],
                                      epochs=2)

                if n_batch % 100 == 0:
                    # Reuse the particle index tensor built before the loop
                    # instead of rebuilding it for every report.
                    self.print_deemb_quality(particles, embedder, deembedder)

                    self.show_heatmaps(emb_data[:30, :],
                                       gen_data[:30, :],
                                       reprod=False,
                                       save=True,
                                       epoch=epoch,
                                       batch=n_batch)
                    err_kld, err_wass = self.gen_show_comp_hists(
                        generator,
                        _data,
                        attr_idxs=[
                            FEATURES - 8, FEATURES - 7, FEATURES - 6,
                            FEATURES - 5
                        ],
                        embedders=[embedder],
                        emb=False,
                        deembedder=deembedder,
                        save=True,
                        epoch=epoch,
                        batch=n_batch)

                    self.errs_kld.append(err_kld)
                    self.errs_wass.append(err_wass)

                    print(
                        f'Epoch: {str(epoch)}/{epochs} :: '
                        f'Batch: {str(n_batch)}/{str(len(data_train))} :: '
                        f'generator loss: {"{:.6f}".format(round(g_loss.item(), 6))} :: '
                        f'discriminator loss: {"{:.6f}".format(round(d_loss.item(), 6))} :: '
                        f'err kld: {"{:.6f}".format(round(err_kld, 6))} :: '
                        f'err wass: {"{:.6f}".format(round(err_wass, 6))}')

            self._save_models(generator, discriminator, embedder, deembedder)

            with open(self.ERRS_SAVE_PATH, 'wb') as handle:
                pickle.dump((self.errs_kld, self.errs_wass),
                            handle,
                            protocol=pickle.HIGHEST_PROTOCOL)

        return generator, discriminator, embedder, deembedder