Example #1
0
            # Penalty from `one_sided` (defined elsewhere; presumably a
            # one-sided hinge) applied to the gap between the mean encodings
            # of X and Y. mean(0) reduces dim 0 — presumably the batch axis;
            # confirm.
            one_side_errG = one_sided(f_enc_X.mean(0) - f_enc_Y.mean(0))

            # Generator objective: sqrt of mmd2_G plus the lambda_rg-weighted
            # one-sided term.
            # NOTE(review): sqrt of a negative value yields NaN. If mmd2_G is
            # an unbiased MMD^2 estimate it can dip below zero — confirm it is
            # clamped where it is computed.
            errG = torch.sqrt(mmd2_G) + lambda_rg * one_side_errG
            # `one` (defined outside this view) is the gradient seed for
            # backward. NOTE(review): if `one` has shape [1] while errG is
            # 0-dim, recent PyTorch rejects the shape mismatch — confirm.
            errG.backward(one)
            # Presumably the generator's optimizer (inferred from the name).
            optimizerG.step()

            # Count completed generator updates; drives logging/sampling below.
            gen_iterations += 1

        # Elapsed wall-clock time in minutes since `time` was recorded.
        run_time = (timeit.default_timer() - time) / 60.0
        # Log progress. Scalars are read with .item() rather than .data[0]:
        # reductions such as .mean() return 0-dim tensors (PyTorch >= 0.4),
        # and indexing a 0-dim tensor with [0] raises IndexError in
        # PyTorch >= 0.5. .item() also works for 1-element tensors.
        print(
            '[%3d/%3d][%3d/%3d] [%5d] (%.2f m) MMD2_D %.6f hinge %.6f L2_AE_X %.6f L2_AE_Y %.6f loss_D %.6f Loss_G %.6f f_X %.6f f_Y %.6f |gD| %.4f |gG| %.4f'
            % (t, args.max_iter, i, len(trn_loader), gen_iterations, run_time,
               mmd2_D.item(), one_side_errD.item(), L2_AE_X_D.item(),
               L2_AE_Y_D.item(), errD.item(), errG.item(),
               f_enc_X_D.mean().item(), f_enc_Y_D.mean().item(),
               base_module.grad_norm(netD), base_module.grad_norm(netG)))

        # Every 500 generator iterations, write image grids of generator
        # samples (from the fixed noise) and of the decoded X batch.
        if gen_iterations % 500 == 0:
            # NOTE(review): netG runs outside torch.no_grad() here, so an
            # autograd graph may be built just for visualization — confirm
            # whether that is intended.
            y_fixed = netG(fixed_noise)
            # x -> 0.5 * x + 0.5: maps [-1, 1] to [0, 1], assuming the
            # generator output lies in [-1, 1] (e.g. tanh) — confirm.
            y_fixed.data = y_fixed.data.mul(0.5).add(0.5)
            # Reshape decoder output to (N, nc, image_size, image_size).
            # NOTE(review): this rebinding and the in-place .data rescale
            # below persist after this block; later uses of f_dec_X_D in the
            # enclosing function see the reshaped, rescaled tensor.
            f_dec_X_D = f_dec_X_D.view(f_dec_X_D.size(0), args.nc,
                                       args.image_size, args.image_size)
            f_dec_X_D.data = f_dec_X_D.data.mul(0.5).add(0.5)
            vutils.save_image(
                y_fixed.data,
                '{0}/fake_samples_{1}.png'.format(args.experiment,
                                                  gen_iterations))
            vutils.save_image(
                f_dec_X_D.data,
                '{0}/decode_samples_{1}.png'.format(args.experiment,
                                                    gen_iterations))
Example #2
0
            # Penalty from `one_sided` (defined elsewhere; presumably a
            # one-sided hinge) applied to the gap between the mean encodings
            # of X and Y. mean(0) reduces dim 0 — presumably the batch axis;
            # confirm.
            one_side_errG = one_sided(f_enc_X.mean(0) - f_enc_Y.mean(0))

            # Generator objective: sqrt of mmd2_G plus the lambda_rg-weighted
            # one-sided term.
            # NOTE(review): sqrt of a negative value yields NaN. If mmd2_G is
            # an unbiased MMD^2 estimate it can dip below zero — confirm it is
            # clamped where it is computed.
            errG = torch.sqrt(mmd2_G) + lambda_rg * one_side_errG
            # `one` (defined outside this view) is the gradient seed for
            # backward. NOTE(review): if `one` has shape [1] while errG is
            # 0-dim, recent PyTorch rejects the shape mismatch — confirm.
            errG.backward(one)
            # Presumably the generator's optimizer (inferred from the name).
            optimizerG.step()

            # Count completed generator updates; drives logging/sampling below.
            gen_iterations += 1

        # Minutes of wall-clock time elapsed since `time` was recorded.
        run_time = (timeit.default_timer() - time) / 60.0
        # One progress line: loop counters and elapsed minutes, eight tensor
        # scalars (losses and mean encodings), then both networks' gradient
        # norms. The generator expression keeps its loop name out of the
        # enclosing function's scope.
        print('[%3d/%3d][%3d/%3d] [%5d] (%.2f m) MMD2_D %.6f hinge %.6f L2_AE_X %.6f L2_AE_Y %.6f loss_D %.6f Loss_G %.6f f_X %.6f f_Y %.6f |gD| %.4f |gG| %.4f' % (
            (t, args.max_iter, i, len(trn_loader), gen_iterations, run_time)
            + tuple(scalar.item() for scalar in (
                mmd2_D, one_side_errD, L2_AE_X_D, L2_AE_Y_D, errD, errG,
                f_enc_X_D.mean(), f_enc_Y_D.mean()))
            + (base_module.grad_norm(netD), base_module.grad_norm(netG))))

        if gen_iterations % 500 == 0:
            y_fixed = netG(fixed_noise)
            y_fixed.data = y_fixed.data.mul(0.5).add(0.5)
            f_dec_X_D = f_dec_X_D.view(f_dec_X_D.size(0), args.nc,
                                       args.image_size, args.image_size)
            f_dec_X_D.data = f_dec_X_D.data.mul(0.5).add(0.5)
            vutils.save_image(
                y_fixed.data,
                '{0}/fake_samples_{1}.png'.format(args.experiment,
                                                  gen_iterations))
            vutils.save_image(
                f_dec_X_D.data,
                '{0}/decode_samples_{1}.png'.format(args.experiment,
Example #3
0
                sys.exit('Finished diagnostics. I\'m out')

        # Do various logs and print summaries.
        run_time = (timeit.default_timer() - time) / 60.0
        if do_log(gen_iterations):
            if gen_iterations % 1000 == 0:
                print(args)
            # Print summary.
            # Scalars are read with .item() rather than .data[0]: reductions
            # such as .mean() return 0-dim tensors (PyTorch >= 0.4), and
            # indexing a 0-dim tensor with [0] raises IndexError in
            # PyTorch >= 0.5. .item() also works for 1-element tensors.
            print(('[Epoch %3d/%3d][Batch %3d/%3d] [%5d] (%.2f m) MMD2_D %.6f '
                   'hinge %.6f L2_AE_X %.6f L2_AE_Y %.6f loss_D %.6f Loss_G '
                   '%.6f f_X %.6f f_Y %.6f |gD| %.4f |gG| %.4f') %
                  (global_step, args.max_iter, batch_iter, len(trn_loader),
                   gen_iterations, run_time, mmd2_D.item(),
                   one_side_errD.item(), L2_AE_X_D.item(), L2_AE_Y_D.item(),
                   errD.item(), errG.item(), f_enc_X_D.mean().item(),
                   f_enc_Y_D.mean().item(), base_module.grad_norm(netD),
                   base_module.grad_norm(netG)))
            # Save metrics for the run.
            with open(
                    os.path.join(save_dir, 'log_x_eval_regression_error.txt'),
                    'a') as f:
                f.write('{:.6f}\n'.format(x_eval_error))
            with open(
                    os.path.join(save_dir,
                                 'log_x_eval_regression_error_1s.txt'),
                    'a') as f:
                f.write('{:.6f}\n'.format(x_eval_error_1s))
            with open(
                    os.path.join(save_dir,
                                 'log_x_eval_regression_error_0s.txt'),
                    'a') as f: