

Influence estimation for generative adversarial networks

10 May 2021

Naoyuki Terashita
Research & Development Group, Hitachi, Ltd.

The expanding range of applications of generative adversarial networks (GANs) [1, 2] makes improving their generative performance increasingly crucial. An effective approach to improving machine learning models is to identify training instances that “harm” the model’s performance. Recent studies [3, 4] replaced the traditional manual screening of a dataset with “influence estimation”: they evaluate the harmfulness of a training instance based on how the model’s performance is expected to change when the instance is removed from the dataset. A typical harmful instance is a wrongly labeled one, e.g., a “dog” image labeled as “cat”. Influence estimation judges this mislabeled image to be harmful when its removal is predicted to improve the performance (Figure 1).


Figure 1: Example of a harmful instance in supervised learning
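
To make this removal-based notion concrete, the following is a minimal brute-force sketch of leave-one-out influence on a toy supervised model. It is our illustration only: the scikit-learn classifier and the helper functions are hypothetical stand-ins, and the studies cited above approximate this quantity without actually retraining for every instance.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

def test_loss(model, X_test, y_test):
    # Negative mean log-likelihood of the true labels on held-out data.
    proba = model.predict_proba(X_test)
    return -np.mean(np.log(proba[np.arange(len(y_test)), y_test]))

def loo_influences(X, y, X_test, y_test):
    """Influence of each training instance on the test loss, measured by
    brute-force leave-one-out retraining. A positive value means the test
    loss drops when the instance is removed, i.e. the instance is harmful."""
    base = test_loss(LogisticRegression().fit(X, y), X_test, y_test)
    influences = np.empty(len(X))
    for i in range(len(X)):
        mask = np.arange(len(X)) != i  # drop instance i, retrain from scratch
        loo = test_loss(LogisticRegression().fit(X[mask], y[mask]), X_test, y_test)
        influences[i] = base - loo  # expected improvement upon removing i
    return influences
```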

Previous studies succeeded in identifying harmful instances in supervised learning, such as classification tasks. Supervised learning is normally a one-model setting: there is a single model, so the absence of a training instance affects that model directly. Extending previous approaches to GANs is not straightforward because GANs are trained in a two-model setting, in which a generator tries to generate realistic samples and a discriminator tries to correctly classify whether samples are real or generated. Training instances are thus fed only into the discriminator and affect the generator only “indirectly” (Figure 2).


Figure 2: The “indirect effect” of training instances, a major issue for influence estimation in generative adversarial networks.
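
The asymmetry can be seen in a standard alternating GAN update. The following minimal PyTorch sketch (our illustration, not code from the paper) shows one training step: the real instances `x_real` enter only the discriminator's loss, so the generator learns about the data solely through the discriminator's response.

```python
import torch
import torch.nn.functional as F

def gan_step(G, D, x_real, opt_G, opt_D, z_dim):
    """One alternating update of a vanilla GAN. Assumes D outputs one
    logit per sample; the training instances `x_real` appear only in
    the discriminator's loss below."""
    n = x_real.size(0)
    ones, zeros = torch.ones(n, 1), torch.zeros(n, 1)
    z = torch.randn(n, z_dim)

    # Discriminator update: the only place real instances enter training.
    d_loss = (F.binary_cross_entropy_with_logits(D(x_real), ones)
              + F.binary_cross_entropy_with_logits(D(G(z).detach()), zeros))
    opt_D.zero_grad(); d_loss.backward(); opt_D.step()

    # Generator update: no real instance appears here; the data influences
    # the generator only indirectly, through the updated discriminator D.
    g_loss = F.binary_cross_entropy_with_logits(D(G(z)), ones)
    opt_G.zero_grad(); g_loss.backward(); opt_G.step()
    return d_loss.item(), g_loss.item()
```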

To solve this problem, we proposed a novel influence estimation method that uses the Jacobian of the gradient of the discriminator's loss with respect to the generator’s parameters (and vice versa) to trace how the absence of an instance from the discriminator’s training affects the generator. We also proposed evaluating the harmfulness of an instance by estimating how much a GAN evaluation metric is expected to improve when the instance is removed. We named this measure of harmfulness the “influence on GAN evaluation metric”.
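
As a simplified illustration of the first-order machinery involved (not the paper's full estimator), the sketch below uses PyTorch's double backward to propagate a small perturbation of the discriminator's parameters, such as one caused by dropping an instance from a discriminator update, into the generator's next SGD step. It relies on the symmetry of mixed second derivatives; all names here are ours.

```python
import torch

def generator_step_perturbation(g_loss, gen_params, disc_params, delta_disc, lr):
    """First-order change in the generator's SGD update
        theta_G <- theta_G - lr * grad_G L_G(theta_G, theta_D)
    when the discriminator parameters are shifted by `delta_disc`.
    Uses the mixed-derivative identity
        (d^2 L / d theta_G d theta_D) v = grad_{theta_G} <grad_{theta_D} L, v>.
    """
    # First backward pass: gradient of the generator loss w.r.t. the
    # discriminator, keeping the graph for a second differentiation.
    grad_d = torch.autograd.grad(g_loss, disc_params, create_graph=True)
    inner = sum((g * v).sum() for g, v in zip(grad_d, delta_disc))
    # Second backward pass: Jacobian-vector product into generator space.
    jvp = torch.autograd.grad(inner, gen_params)
    # SGD subtracts lr * gradient, hence the minus sign.
    return [-lr * j for j in jvp]
```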

We then evaluated how removing the instances flagged as harmful by the influence on GAN evaluation metric improves the performance of the generator. We tested two datasets: 2D multivariate Gaussian (2D-Normal) and MNIST. Harmful instances in 2D-Normal were identified based on the influence on average log-likelihood (ALL). For MNIST, we estimated harmful instances based on the influence on inception score (IS) [5] and Fréchet inception distance (FID) [6]. In both setups, instances were also selected using baseline approaches: anomaly detection (Isolation Forest), influence on the discriminator loss (Disc. Loss), and random selection (Random). Figure 3 shows the average test GAN evaluation metrics across the repeated experiments, plotted against the number of removed instances for each selection approach. In all settings, our data cleansing approach improved the GAN evaluation metrics with statistical significance and outperformed the baselines.


Figure 3: (a) Average log-likelihood (ALL), (b) inception score (IS), and (c) Fréchet inception distance (FID) evaluated with a test dataset and latent variables after data cleansing. Larger values of ALL and IS, and smaller values of FID, indicate better generative performance.
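
The cleansing protocol behind these curves is conceptually simple. Here is a hedged sketch, in which `estimate_influences`, `train_gan`, and `evaluate_metric` are hypothetical stand-ins for the paper's actual components:

```python
import numpy as np

def data_cleansing(dataset, n_remove, estimate_influences, train_gan, evaluate_metric):
    """Remove the instances predicted to be most harmful, retrain, re-evaluate.
    `estimate_influences(dataset)[i]` is assumed to be the predicted improvement
    of the GAN evaluation metric when instance i is removed (all three callables
    are illustrative stand-ins, not the paper's API)."""
    influences = estimate_influences(dataset)
    # Sort descending: instances whose removal should help the most come first.
    harmful = set(np.argsort(influences)[::-1][:n_remove])
    cleansed = [x for i, x in enumerate(dataset) if i not in harmful]
    generator = train_gan(cleansed)
    return evaluate_metric(generator), sorted(harmful)
```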

We also examined the characteristics of harmful instances. Overall, harmful instances were found to belong to regions from which the generator sampled too frequently compared to the true distribution (Figures 4a and 5a). Removing those instances alleviated the oversampling: the generator's distribution for 2D-Normal moved closer to the true distribution (Figure 4b), and the generated MNIST samples changed from images of the digit 1 to those of other digits after data cleansing (highlighted samples in Figures 5b and 5c).


Figure 4: Harmfulness of 2D-Normal instances illustrated using (a) the influence on ALL and (b) changes in the generator's distribution. (b) plots the true distribution (True) and the generator's distribution before (no removal) and after (cleansed) data cleansing.


Figure 5: (a) The top 36 harmful MNIST instances predicted based on the influence on FID, and generated test samples (b) before and (c) after data cleansing.

This is the first research to enable influence estimation for GANs. Our proposal has the potential not only to improve generative performance but also to offer an important avenue toward interpreting the training dynamics of GANs. For more details, we encourage you to read our paper, “Influence Estimation for Generative Adversarial Networks,” which was presented at the Ninth International Conference on Learning Representations (ICLR 2021) on 4 May 2021.


References

[1]
Christian Ledig, Lucas Theis, Ferenc Huszar, Jose Caballero, Andrew Cunningham, Alejandro Acosta, Andrew Aitken, Alykhan Tejani, Johannes Totz, Zehan Wang, and Wenzhe Shi. Photo-realistic single image super-resolution using a generative adversarial network. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR), pp. 4681-4690, 2017.
[2]
Jiajun Wu, Chengkai Zhang, Tianfan Xue, Bill Freeman, and Josh Tenenbaum. Learning a probabilistic latent space of object shapes via 3D generative-adversarial modeling. In D.D. Lee, M. Sugiyama, U.V. Luxburg, I. Guyon, and R. Garnett (Eds.), Advances in Neural Information Processing Systems 29, Proceedings of the 30th International Conference on Neural Information Processing Systems, Curran Associates, Inc., pp. 82-90, 2016.
[3]
Rajiv Khanna, Been Kim, Joydeep Ghosh, and Sanmi Koyejo. Interpreting black box predictions using Fisher kernels. In Proceedings of Machine Learning Research, PMLR 89, pp. 3382-3390, 2019.
[4]
Satoshi Hara, Atsushi Nitanda, and Takanori Maehara. Data cleansing for models trained with SGD. In H. Wallach, H. Larochelle, A. Beygelzimer, F. d'Alché-Buc, E. Fox, and R. Garnett (Eds.), Advances in Neural Information Processing Systems 32, Proceedings of the 33rd International Conference on Neural Information Processing Systems, Curran Associates, Inc., pp. 4215-4224, 2019.
[5]
Tim Salimans, Ian Goodfellow, Wojciech Zaremba, Vicki Cheung, Alec Radford, and Xi Chen. Improved techniques for training GANs. In D.D. Lee, M. Sugiyama, U.V. Luxburg, I. Guyon, and R. Garnett (Eds.), Advances in Neural Information Processing Systems 29, Proceedings of the 30th International Conference on Neural Information Processing Systems (NIPS’16), Curran Associates, Inc., pp. 2234-2242, 2016.
[6]
Martin Heusel, Hubert Ramsauer, Thomas Unterthiner, Bernhard Nessler, and Sepp Hochreiter. GANs trained by a two time-scale update rule converge to a local Nash equilibrium. In Advances in Neural Information Processing Systems 30, Proceedings of the 31st International Conference on Neural Information Processing Systems, pp. 6629-6647, 2017.