示例#1
0
            if count_update_step_cgan == 1:
                viz = Visdom()
                x_value = np.asarray(count_update_step_cgan).reshape(1, )
                x_label = 'Training Step'
                y_value = np.column_stack(
                    (np.asarray(loss_dis.item()), np.asarray(loss_gen.item())))
                y_label = 'Loss'
                title = 'Discriminator and Generator Losses'
                legend = ['Loss_Dis', 'Loss_Gen']
                win_dis_gen = creat_vis_plot(viz, x_value, y_value, x_label,
                                             y_label, title, legend)
            elif count_update_step_cgan % 50 == 0:
                x_value = np.asarray(count_update_step_cgan).reshape(1, )
                y_value = np.column_stack(
                    (np.asarray(loss_dis.item()), np.asarray(loss_gen.item())))
                update_vis(viz, win_dis_gen, x_value, y_value)

            # evaluate the model
            if count_update_step_cgan % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_dis: {:.4f}'
                      '\nmean loss_gen: {:.4f}'.format(count_update_step_cgan,
                                                       loss_dis.item(),
                                                       loss_gen.item()))

            # save the midterm model sates
            if count_update_step_cgan % 5000 == 0:
                torch.save(
                    cgan.state_dict(), save_dir + '/cgan_step' +
                    str(count_update_step_cgan) + '.pt')
                # visualize_results_cgan(params, cgan, device, valid_loader)
示例#2
0
def training(params, encoder, decoder, optim_encoder, optim_decoder,
             lr_scheduler_encoder, lr_scheduler_decoder, device, attrs_class,
             decorr_regul, train_loader, valid_loader, save_dir):
    """Alternately train an attribute encoder and a conditional decoder.

    Every mini-batch performs two updates:

    1. Encoder update (decoder parameters frozen): weighted sum of the
       attribute classification loss, the decorrelation loss between
       ``sigmoid(y_attrs)`` and the latent code, and the image
       reconstruction loss.
    2. Decoder update (encoder parameters frozen, latent code detached):
       weighted image reconstruction loss only.

    Side work performed during training:

    * training losses are plotted to Visdom at step 1 and every 50 steps;
    * every 1000 steps the current losses are printed and the model is
      evaluated with ``evaluate_classification``; validation metrics are
      printed and plotted;
    * state dicts are written to ``save_dir`` every 5000 steps, and every
      1000 steps strictly between steps 10000 and 20000;
    * both LR schedulers are stepped once per mini-batch.

    Args:
        params: config object; reads ``n_epochs``, ``n_valid``,
            ``lambda_class``, ``lambda_decorr`` and ``lambda_recons``.
        encoder: module mapping images to ``(y_attrs, z_latent)``, where
            ``y_attrs`` are pre-sigmoid attribute scores.
        decoder: module called as ``decoder(attributes, z_latent)``.
        optim_encoder, optim_decoder: optimizers of the two modules.
        lr_scheduler_encoder, lr_scheduler_decoder: schedulers stepped after
            every mini-batch.
        device: device the batches are moved to.
        attrs_class: criterion called as ``attrs_class(y_attrs, labels)``.
        decorr_regul: criterion called as
            ``decorr_regul(sigmoid(y_attrs), z_latent)``.
        train_loader: iterable of dicts with ``'image'`` and ``'attributes'``.
        valid_loader: passed to ``evaluate_classification``.
        save_dir: directory receiving the ``.pt`` checkpoint files.

    Returns:
        int: total number of update steps performed.
    """

    def to_point(value):
        # The Visdom helpers take 1-element arrays for both axes.
        return np.asarray(value).reshape(1, )

    def recons_loss(images, recons):
        # Mean over all pixels of 0.5 * squared error.
        return torch.mean(
            0.5 * (images.view(len(images), -1) -
                   recons.view(len(recons), -1))**2)

    def save_step_checkpoint(step):
        # Mid-training snapshot of both networks, tagged with the step.
        torch.save(encoder.state_dict(),
                   save_dir + '/encoder_step' + str(step) + '.pt')
        torch.save(decoder.state_dict(),
                   save_dir + '/decoder_step' + str(step) + '.pt')

    encoder.train()
    decoder.train()

    count_update_step = 0
    viz = None

    for _ in range(params.n_epochs):

        for sample_batched in train_loader:
            batch_x = sample_batched['image'].to(device)
            batch_y = sample_batched['attributes'].to(device)

            count_update_step += 1

            ############################
            # (1) update encoder
            ############################
            for p in encoder.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)
            optim_encoder.zero_grad()

            y_attrs, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y, z_latent)

            loss_class = attrs_class(y_attrs, batch_y)
            loss_decorr = decorr_regul(torch.sigmoid(y_attrs), z_latent)
            loss_recons_image = recons_loss(batch_x, x_recons)
            loss_encoder = (params.lambda_class * loss_class +
                            params.lambda_decorr * loss_decorr +
                            params.lambda_recons * loss_recons_image)

            loss_encoder.backward()
            optim_encoder.step()

            ############################
            # (2) update decoder
            ############################
            for p in decoder.parameters():
                p.requires_grad_(True)
            for p in encoder.parameters():
                p.requires_grad_(False)
            optim_decoder.zero_grad()

            # Only the latent code is needed here; the attribute scores of
            # this second forward pass are discarded.
            _, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y, z_latent.detach())

            loss_recons_image = recons_loss(batch_x, x_recons)
            loss_decoder = params.lambda_recons * loss_recons_image

            loss_decoder.backward()
            optim_decoder.step()

            # visualize training losses
            if count_update_step == 1:
                viz = Visdom()
                x_value = to_point(count_update_step)
                win_img_recons = creat_vis_plot(
                    viz, x_value, to_point(loss_recons_image.item()),
                    'Training Step', 'MSE', 'Image Reconstruction Loss',
                    ['MSE'])
                win_attrs_class = creat_vis_plot(
                    viz, x_value, to_point(loss_class.item()),
                    'Training Step', 'Loss', 'Binary Cross Entropy Loss',
                    ['Loss_BCE'])
                win_decorr = creat_vis_plot(
                    viz, x_value, to_point(loss_decorr.item()),
                    'Training Step', 'Loss', 'Decorrelation Loss',
                    ['Loss_Decorr'])
            elif count_update_step % 50 == 0:
                x_value = to_point(count_update_step)
                update_vis(viz, win_img_recons, x_value,
                           to_point(loss_recons_image.item()))
                update_vis(viz, win_attrs_class, x_value,
                           to_point(loss_class.item()))
                update_vis(viz, win_decorr, x_value,
                           to_point(loss_decorr.item()))

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_attrs_class: {:.4f}'
                      '\nmean loss_decorr: {:.4f}'.format(
                          count_update_step, loss_recons_image.item(),
                          loss_class.item(), loss_decorr.item()))

                # evaluation on validation set
                _, er_total_valid, loss_decorr_valid, loss_recons_valid = \
                    evaluate_classification(encoder, decoder, valid_loader,
                                            params.n_valid, decorr_regul,
                                            device)
                print(
                    'Attribute Classification Error Rate on Validation Set: {:.2f}%'
                    .format(er_total_valid))
                print('Decorrelation Error on Each Mini-Batch: {:.4f}'.format(
                    loss_decorr_valid))
                print('Reconstruction Error on Each Sample: {:.4f}'.format(
                    loss_recons_valid))

                # Reuse the Visdom client opened at step 1 rather than
                # connecting a second one for the validation plots.
                x_value = to_point(count_update_step)
                if count_update_step == 1000:
                    win_img_recons_valid = creat_vis_plot(
                        viz, x_value, to_point(loss_recons_valid),
                        'Training Step', 'Reconstruction Error',
                        'Valid Image Reconstruction Error', ['Recons_Error'])
                    win_attrs_class_valid = creat_vis_plot(
                        viz, x_value, to_point(er_total_valid),
                        'Training Step', 'Classification Error',
                        'Valid Attr. Classification Error', ['Class_Error'])
                    win_decorr_valid = creat_vis_plot(
                        viz, x_value, to_point(loss_decorr_valid),
                        'Training Step', 'Decorrelation Error',
                        'Valid Decorrelation Error', ['Decorr_Error'])
                else:
                    update_vis(viz, win_img_recons_valid, x_value,
                               to_point(loss_recons_valid))
                    update_vis(viz, win_attrs_class_valid, x_value,
                               to_point(er_total_valid))
                    update_vis(viz, win_decorr_valid, x_value,
                               to_point(loss_decorr_valid))

                # evaluate_classification may leave the modules in eval mode.
                encoder.train()
                decoder.train()

            # Mid-term checkpoints: every 5000 steps, plus every 1000 steps
            # strictly between 10000 and 20000. A single combined condition
            # avoids writing the step-15000 files twice.
            if (count_update_step % 5000 == 0 or
                    (10000 < count_update_step < 20000 and
                     count_update_step % 1000 == 0)):
                save_step_checkpoint(count_update_step)

            lr_scheduler_encoder.step()
            lr_scheduler_decoder.step()

    # save the whole model
    torch.save(encoder.state_dict(), save_dir + '/encoder_final.pt')
    torch.save(decoder.state_dict(), save_dir + '/decoder_final.pt')

    return count_update_step
示例#3
0
def training(params, encoder_y, encoder_z, decoder, discriminator,
             optim_encoder_y, optim_encoder_z, optim_decoder,
             optim_discriminator, lr_scheduler_encoder_z, lr_scheduler_decoder,
             lr_scheduler_dis, device, attrs_class, decorr_regul, train_loader,
             valid_loader, margin, equilibrium, save_dir,
             log_decorr_time=False):
    """Two-phase training: attribute encoder first, then encoder_z/decoder/GAN.

    Phase 1 trains ``encoder_y`` alone on the attribute classification loss
    for ``params.n_epochs_EncY`` epochs. It validates every 500 steps and
    checkpoints every 1000 steps.

    Phase 2 keeps ``encoder_y`` frozen in eval mode. For every mini-batch it
    updates, in turn:

    1. the discriminator on the GAN loss, with reconstructions as the fake
       samples;
    2. the decoder on image + feature reconstruction minus the weighted GAN
       loss;
    3. ``encoder_z`` on image + feature reconstruction plus the scheduled
       decorrelation loss.

    Steps (1) and (2) are gated by an equilibrium rule. The decoder step is
    skipped when either mean GAN term exceeds ``equilibrium + margin``. The
    discriminator step is skipped when either term falls below
    ``equilibrium - margin``. If both would be skipped, both run.

    Args:
        params: config object; reads ``n_epochs_EncY``, ``n_epochs``,
            ``n_valid``, ``lambda_recons``, ``lambda_dis``,
            ``lambda_decorr`` and ``lambda_schedule``.
        encoder_y: attribute encoder returning pre-sigmoid scores.
        encoder_z: latent-code encoder.
        decoder: called as ``decoder(z_latent, sigmoid(y_attrs))``.
        discriminator: called as ``discriminator(fake, real, mode=...)`` with
            ``mode='GAN'`` or ``mode='REC'``. The first ``len(fake)`` output
            rows belong to ``fake`` and the remaining rows to ``real``.
        optim_encoder_y, optim_encoder_z, optim_decoder, optim_discriminator:
            optimizers of the corresponding modules.
        lr_scheduler_encoder_z, lr_scheduler_decoder, lr_scheduler_dis:
            schedulers stepped once per phase-2 mini-batch.
        device: device the batches are moved to.
        attrs_class: phase-1 criterion, applied to the sigmoid outputs.
        decorr_regul: criterion called as ``decorr_regul(sigmoid(y), z)``.
        train_loader: iterable of dicts with ``'image'`` and ``'attributes'``.
        valid_loader: passed to ``evaluate_class`` / ``evaluate_learning``.
        margin, equilibrium: thresholds of the GAN balancing rule.
        save_dir: directory receiving the ``.pt`` checkpoint files.
        log_decorr_time: if True, print the wall-clock time of every
            ``decorr_regul`` call. Off by default; the original printed this
            on every batch unconditionally.

    Returns:
        int: number of phase-2 update steps performed.
    """

    def to_point(value):
        # The Visdom helpers take 1-element arrays for both axes.
        return np.asarray(value).reshape(1, )

    def pair_point(first, second):
        # Two series plotted in one Visdom window.
        return np.column_stack((np.asarray(first), np.asarray(second)))

    def recons_image_loss(images, recons):
        # Mean over all pixels of 0.5 * squared error.
        return torch.mean(
            0.5 * (images.view(len(images), -1) -
                   recons.view(len(recons), -1))**2)

    def recons_feature_loss(mid_repre, n):
        # Rows [:n] are features of the reconstructions, rows [n:] of the
        # original images.
        return torch.mean(0.5 * (mid_repre[n:] - mid_repre[:n])**2)

    def gan_terms(dis_output, n):
        # Rows [:n] score the reconstructions (fake), rows [n:] the originals
        # (real). The 1e-3 offsets keep the logs finite.
        dis_original = -torch.log(dis_output[n:] + 1e-3)
        dis_sampled = -torch.log(1 - dis_output[:n] + 1e-3)
        return torch.mean(dis_original), torch.mean(dis_sampled)

    def balance(mean_original, mean_sampled):
        # Equilibrium rule; returns (train_dis, train_dec).
        train_dis = True
        train_dec = True
        if (mean_original > equilibrium + margin
                or mean_sampled > equilibrium + margin):
            train_dec = False
        if (mean_original < equilibrium - margin
                or mean_sampled < equilibrium - margin):
            train_dis = False
        if not train_dec and not train_dis:
            train_dis = True
            train_dec = True
        return train_dis, train_dec

    ################################################
    # training Encoder_Y
    ################################################
    encoder_y.train()
    count_update_step_class = 0
    viz = None

    for _ in range(params.n_epochs_EncY):

        for sample_batched in train_loader:
            batch_x = sample_batched['image'].to(device)
            batch_y = sample_batched['attributes'].to(device)

            optim_encoder_y.zero_grad()
            y_attrs = encoder_y(batch_x)
            # NOTE(review): attrs_class receives sigmoid outputs here; confirm
            # it expects probabilities (e.g. BCELoss) rather than logits.
            loss_class = attrs_class(torch.sigmoid(y_attrs), batch_y)
            loss_class.backward()
            optim_encoder_y.step()

            count_update_step_class += 1
            if count_update_step_class % 500 == 0:
                # evaluation on validation set
                _, er_total_valid = evaluate_class(
                    encoder_y, valid_loader, params.n_valid, device)
                print(
                    '\nUpdate step: {:d} '
                    '\nAttribute Classification Error Rate on Validation Set: {:.4f}%'
                    .format(count_update_step_class, er_total_valid))

                x_value = to_point(count_update_step_class)
                y_value = to_point(er_total_valid)
                if count_update_step_class == 500:
                    viz = Visdom()
                    win_attrs_class_valid = creat_vis_plot(
                        viz, x_value, y_value, 'Training Step',
                        'Classification Error',
                        'Valid Attr. Classification Error', ['Class_Error'])
                else:
                    update_vis(viz, win_attrs_class_valid, x_value, y_value)

                # evaluate_class may leave the module in eval mode.
                encoder_y.train()

            # save the midterm model state
            if count_update_step_class % 1000 == 0:
                torch.save(
                    encoder_y.state_dict(), save_dir + '/encoder_y_step' +
                    str(count_update_step_class) + '.pt')

    _, er_total_valid = evaluate_class(encoder_y, valid_loader, params.n_valid,
                                       device)
    print(
        '\nFinal Attribute Classification Error Rate on Validation Set: {:.4f}%'
        .format(er_total_valid))
    torch.save(encoder_y.state_dict(), save_dir + '/encoder_y_final.pt')

    ################################################
    # training Encoder_Z, Decoder, and Discriminator
    ################################################
    encoder_y.eval()
    encoder_z.train()
    decoder.train()
    discriminator.train()

    count_update_step = 0

    for _ in range(params.n_epochs):

        for batch_idx, sample_batched in enumerate(train_loader):
            batch_x = sample_batched['image'].to(device)
            n = batch_x.size(0)

            count_update_step += 1

            # encoder_y is frozen in phase 2 (eval mode, never stepped), so
            # its predictions are computed once per batch without autograd.
            # Otherwise every backward pass below would also traverse
            # encoder_y and accumulate never-cleared gradients in its params.
            with torch.no_grad():
                y_attrs_sigmoid = torch.sigmoid(encoder_y(batch_x))

            ############################
            # (1) update discriminator
            ############################
            for p in discriminator.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)
            for p in encoder_z.parameters():
                p.requires_grad_(False)
            optim_discriminator.zero_grad()

            z_latent = encoder_z(batch_x)
            x_recons = decoder(z_latent.detach(), y_attrs_sigmoid)
            # using x_recons as fake data
            dis_output = discriminator(x_recons.detach(), batch_x, mode='GAN')
            mean_original, mean_sampled = gan_terms(dis_output, n)
            loss_discriminator = mean_original + mean_sampled

            train_dis, _ = balance(mean_original.item(), mean_sampled.item())
            if train_dis:
                loss_discriminator.backward()
                optim_discriminator.step()

            ############################
            # (2) update decoder
            ############################
            for p in decoder.parameters():
                p.requires_grad_(True)
            for p in discriminator.parameters():
                p.requires_grad_(False)
            optim_decoder.zero_grad()

            z_latent = encoder_z(batch_x)
            x_recons = decoder(z_latent.detach(), y_attrs_sigmoid)
            mid_repre = discriminator(x_recons, batch_x, mode='REC')
            # using x_recons as fake data
            dis_output = discriminator(x_recons, batch_x, mode='GAN')

            loss_recons_image = recons_image_loss(batch_x, x_recons)
            loss_recons_feature = recons_feature_loss(mid_repre, n)
            mean_original, mean_sampled = gan_terms(dis_output, n)
            loss_discriminator = mean_original + mean_sampled

            loss_decoder = (loss_recons_image +
                            params.lambda_recons * loss_recons_feature -
                            params.lambda_dis * loss_discriminator)

            _, train_dec = balance(mean_original.item(), mean_sampled.item())
            if train_dec:
                loss_decoder.backward()
                optim_decoder.step()

            ############################
            # (3) update encoder_z
            ############################
            for p in encoder_z.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)
            optim_encoder_z.zero_grad()

            z_latent = encoder_z(batch_x)
            x_recons = decoder(z_latent, y_attrs_sigmoid)
            mid_repre = discriminator(x_recons, batch_x, mode='REC')

            # decorrelation loss (optionally timed; wall-clock only, so with
            # asynchronous CUDA kernels this may not reflect compute time)
            start_time = time.time()
            loss_decorr = decorr_regul(y_attrs_sigmoid, z_latent)
            elapsed = time.time() - start_time
            if log_decorr_time:
                print(
                    'Time cost of computing decorr_regul: batch_id={:d}, time={:.9f}'
                    .format(batch_idx, elapsed))

            loss_recons_image = recons_image_loss(batch_x, x_recons)
            loss_recons_feature = recons_feature_loss(mid_repre, n)

            lambda_decorr = get_lambda(params.lambda_decorr,
                                       params.lambda_schedule,
                                       count_update_step)
            loss_encoder_z = (loss_recons_image +
                              params.lambda_recons * loss_recons_feature +
                              lambda_decorr * loss_decorr)

            loss_encoder_z.backward()
            optim_encoder_z.step()

            # visualize losses
            if count_update_step == 1:
                viz = Visdom()
                x_value = to_point(count_update_step)
                win_img_recons = creat_vis_plot(
                    viz, x_value, to_point(loss_recons_image.item()),
                    'Training Step', 'MSE', 'Image Reconstruction Loss',
                    ['MSE'])
                win_feature_recons = creat_vis_plot(
                    viz, x_value, to_point(loss_recons_feature.item()),
                    'Training Step', 'MSE', 'Feature Reconstruction Loss',
                    ['MSE'])
                win_dis_gen = creat_vis_plot(
                    viz, x_value,
                    pair_point(loss_discriminator.item(), loss_decoder.item()),
                    'Training Step', 'Loss',
                    'Discriminator and Decoder Losses',
                    ['Loss_Dis', 'Loss_Dec'])
                win_decorr = creat_vis_plot(
                    viz, x_value, to_point(loss_decorr.item()),
                    'Training Step', 'Loss', 'Decorrelation Loss',
                    ['Loss_Decorr'])
            elif count_update_step % 50 == 0:
                x_value = to_point(count_update_step)
                update_vis(viz, win_img_recons, x_value,
                           to_point(loss_recons_image.item()))
                update_vis(viz, win_feature_recons, x_value,
                           to_point(loss_recons_feature.item()))
                update_vis(viz, win_dis_gen, x_value,
                           pair_point(loss_discriminator.item(),
                                      loss_decoder.item()))
                update_vis(viz, win_decorr, x_value,
                           to_point(loss_decorr.item()))

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_feature_recons: {:.4f}'
                      '\nmean loss_GAN_dis: {:.4f}'
                      '\nmean loss_decoder: {:.4f}'
                      '\nmean loss_decorr: {:.4f}'.format(
                          count_update_step, loss_recons_image.item(),
                          loss_recons_feature.item(),
                          loss_discriminator.item(), loss_decoder.item(),
                          loss_decorr.item()))

                # evaluation on validation set
                loss_decorr_valid, loss_recons_valid = evaluate_learning(
                    encoder_y, encoder_z, decoder, valid_loader,
                    params.n_valid, decorr_regul, device)
                print('Decorrelation Error on Each Mini-Batch: {:.4f}'.format(
                    loss_decorr_valid))
                print('Reconstruction Error on Each Sample: {:.4f}'.format(
                    loss_recons_valid))

                # Reuse the Visdom client opened at step 1.
                x_value = to_point(count_update_step)
                if count_update_step == 1000:
                    win_img_recons_valid = creat_vis_plot(
                        viz, x_value, to_point(loss_recons_valid),
                        'Training Step', 'Reconstruction Error',
                        'Valid Image Reconstruction Error', ['Recons_Error'])
                    win_decorr_valid = creat_vis_plot(
                        viz, x_value, to_point(loss_decorr_valid),
                        'Training Step', 'Decorrelation Error',
                        'Valid Decorrelation Error', ['Decorr_Error'])
                else:
                    update_vis(viz, win_img_recons_valid, x_value,
                               to_point(loss_recons_valid))
                    update_vis(viz, win_decorr_valid, x_value,
                               to_point(loss_decorr_valid))

                # evaluate_learning may leave the modules in eval mode.
                encoder_z.train()
                decoder.train()

            # save the midterm model states
            if count_update_step % 5000 == 0:
                torch.save(
                    encoder_z.state_dict(), save_dir + '/encoder_z_step' +
                    str(count_update_step) + '.pt')
                torch.save(
                    decoder.state_dict(), save_dir + '/decoder_step' +
                    str(count_update_step) + '.pt')
                torch.save(
                    discriminator.state_dict(), save_dir +
                    '/discriminator_step' + str(count_update_step) + '.pt')

            lr_scheduler_encoder_z.step()
            lr_scheduler_decoder.step()
            lr_scheduler_dis.step()

    torch.save(encoder_z.state_dict(), save_dir + '/encoder_z_final.pt')
    torch.save(decoder.state_dict(), save_dir + '/decoder_final.pt')
    torch.save(discriminator.state_dict(),
               save_dir + '/discriminator_final.pt')

    return count_update_step
示例#4
0
def training(params, encoder, decoder, optim_encoder, optim_decoder, device,
             digit_class, decorr_regul, train_loader, valid_loader, save_dir):
    """Alternately train a class encoder and a class-conditional decoder.

    Every mini-batch performs two updates:

    1. Encoder update (decoder parameters frozen): weighted sum of the class
       loss, the decorrelation loss between ``softmax(y_class)`` and the
       latent code, and the image reconstruction loss.
    2. Decoder update (encoder parameters frozen): weighted image
       reconstruction loss only.

    The decoder is conditioned on one-hot labels built per batch. Training
    losses are plotted at step 1 and every 50 steps. Validation metrics from
    ``evaluate_classification`` are printed and plotted every 1000 steps.
    Checkpoints are written every 5000 steps.

    Args:
        params: config object; reads ``n_epochs``, ``n_class``, ``n_valid``,
            ``lambda_class``, ``lambda_decorr`` and ``lambda_recons``.
        encoder: module mapping images to ``(y_class, z_latent)``, where
            ``y_class`` are pre-softmax class scores.
        decoder: module called as ``decoder(onehot_labels, z_latent)``.
        optim_encoder, optim_decoder: optimizers of the two modules.
        device: device the batches are moved to.
        digit_class: criterion called as ``digit_class(y_class, labels)``.
        decorr_regul: criterion called as
            ``decorr_regul(softmax(y_class), z_latent)``.
        train_loader: iterable of ``(images, integer_labels)`` pairs.
        valid_loader: passed to ``evaluate_classification``.
        save_dir: directory receiving the ``.pt`` checkpoint files.

    Returns:
        int: total number of update steps performed.
    """

    def to_point(value):
        # The Visdom helpers take 1-element arrays for both axes.
        return np.asarray(value).reshape(1, )

    def recons_loss(images, recons):
        # Mean over all pixels of 0.5 * squared error.
        return torch.mean(
            0.5 * (images.view(len(images), -1) -
                   recons.view(len(recons), -1))**2)

    encoder.train()
    decoder.train()

    count_update_step = 0
    viz = None

    for _ in range(params.n_epochs):

        for batch_x, batch_y in train_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            # One-hot labels sized to *this* batch. A buffer preallocated with
            # params.batch_size rows would not match the row count of a
            # shorter final batch.
            batch_y_onehot = torch.zeros(
                batch_y.size(0), params.n_class, dtype=torch.float32,
                device=device).scatter_(1, batch_y.view(-1, 1), 1)

            count_update_step += 1

            ############################
            # (1) update encoder
            ############################
            for p in encoder.parameters():
                p.requires_grad_(True)
            for p in decoder.parameters():
                p.requires_grad_(False)
            optim_encoder.zero_grad()

            y_class, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y_onehot, z_latent)

            loss_class = digit_class(y_class, batch_y)
            loss_decorr = decorr_regul(F.softmax(y_class, 1), z_latent)
            loss_recons_image = recons_loss(batch_x, x_recons)
            loss_encoder = (params.lambda_class * loss_class +
                            params.lambda_decorr * loss_decorr +
                            params.lambda_recons * loss_recons_image)

            loss_encoder.backward()
            optim_encoder.step()

            ############################
            # (2) update decoder
            ############################
            for p in decoder.parameters():
                p.requires_grad_(True)
            for p in encoder.parameters():
                p.requires_grad_(False)
            optim_decoder.zero_grad()

            _, z_latent = encoder(batch_x)
            x_recons = decoder(batch_y_onehot, z_latent)

            loss_recons_image = recons_loss(batch_x, x_recons)
            loss_decoder = params.lambda_recons * loss_recons_image

            loss_decoder.backward()
            optim_decoder.step()

            # visualize training losses
            if count_update_step == 1:
                viz = Visdom()
                x_value = to_point(count_update_step)
                win_img_recons = creat_vis_plot(
                    viz, x_value, to_point(loss_recons_image.item()),
                    'Training Step', 'MSE', 'Image Reconstruction Loss',
                    ['MSE'])
                win_class = creat_vis_plot(
                    viz, x_value, to_point(loss_class.item()),
                    'Training Step', 'Loss', 'Cross Entropy Loss',
                    ['Loss_CE'])
                win_decorr = creat_vis_plot(
                    viz, x_value, to_point(loss_decorr.item()),
                    'Training Step', 'Loss', 'Decorrelation Loss',
                    ['Loss_Decorr'])
            elif count_update_step % 50 == 0:
                x_value = to_point(count_update_step)
                update_vis(viz, win_img_recons, x_value,
                           to_point(loss_recons_image.item()))
                update_vis(viz, win_class, x_value,
                           to_point(loss_class.item()))
                update_vis(viz, win_decorr, x_value,
                           to_point(loss_decorr.item()))

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_class: {:.4f}'
                      '\nmean loss_decorr: {:.4f}'.format(
                          count_update_step, loss_recons_image.item(),
                          loss_class.item(), loss_decorr.item()))

                # Validation metrics use their own names so they do not
                # overwrite the training losses above.
                loss_recons_valid, loss_decorr_valid, class_error_rate = \
                    evaluate_classification(params, encoder, decoder,
                                            valid_loader, params.n_valid,
                                            decorr_regul, device)
                print('Classification Error Rate on Validation Set: {:.2f}%'.
                      format(class_error_rate))
                print('Decorrelation Error on Each Mini-Batch: {:.4f}'.format(
                    loss_decorr_valid))
                print('Reconstruction Error on Each Sample: {:.4f}'.format(
                    loss_recons_valid))

                # Reuse the Visdom client opened at step 1.
                x_value = to_point(count_update_step)
                if count_update_step == 1000:
                    win_img_recons_valid = creat_vis_plot(
                        viz, x_value, to_point(loss_recons_valid),
                        'Training Step', 'Reconstruction Error',
                        'Valid Image Reconstruction Error', ['Recons_Error'])
                    win_class_valid = creat_vis_plot(
                        viz, x_value, to_point(class_error_rate),
                        'Training Step', 'Classification Error',
                        'Valid Classification Error', ['Class_Error'])
                    win_decorr_valid = creat_vis_plot(
                        viz, x_value, to_point(loss_decorr_valid),
                        'Training Step', 'Decorrelation Error',
                        'Valid Decorrelation Error', ['Decorr_Error'])
                else:
                    update_vis(viz, win_img_recons_valid, x_value,
                               to_point(loss_recons_valid))
                    update_vis(viz, win_class_valid, x_value,
                               to_point(class_error_rate))
                    update_vis(viz, win_decorr_valid, x_value,
                               to_point(loss_decorr_valid))

                # evaluate_classification may leave the modules in eval mode.
                encoder.train()
                decoder.train()

            # save the midterm model states
            if count_update_step % 5000 == 0:
                torch.save(
                    encoder.state_dict(), save_dir + '/encoder_step' +
                    str(count_update_step) + '.pt')
                torch.save(
                    decoder.state_dict(), save_dir + '/decoder_step' +
                    str(count_update_step) + '.pt')

    # save the whole model
    torch.save(encoder.state_dict(), save_dir + '/encoder_final.pt')
    torch.save(decoder.state_dict(), save_dir + '/decoder_final.pt')

    return count_update_step
示例#5
0
                legend = ['MSE']
                win_feature_recons = creat_vis_plot(viz, x_value, y_value,
                                                    x_label, y_label, title, legend)

                y_value = np.column_stack((np.asarray(loss_discriminator.item()), np.asarray(loss_decoder.item())))
                y_label = 'Loss'
                title = 'Discriminator and Decoder Losses'
                legend = ['Loss_Dis', 'Loss_Dec']
                win_dis_gen = creat_vis_plot(viz, x_value, y_value,
                                             x_label, y_label, title, legend)

            elif count_update_step % 50 == 0:
                x_value = np.asarray(count_update_step).reshape(1, )

                y_value = np.asarray(torch.mean(nle_value).item()).reshape(1, )
                update_vis(viz, win_img_recons, x_value, y_value)

                y_value = np.asarray(torch.mean(mse_value).item()).reshape(1, )
                update_vis(viz, win_feature_recons, x_value, y_value)

                y_value = np.column_stack((np.asarray(loss_discriminator.item()), np.asarray(loss_decoder.item())))
                update_vis(viz, win_dis_gen, x_value, y_value)

            # evaluate the model
            if count_update_step % 1000 == 0:
                print('\nUpdate step: {:d}'
                      '\nmean loss_img_recons: {:.4f}'
                      '\nmean loss_feature_recons: {:.4f}'
                      '\nmean loss_GAN_dis: {:.4f}'
                      '\nmean loss_decoder: {:.4f}'.format(
                    count_update_step, torch.mean(nle_value).item(), torch.mean(mse_value).item(),