# Source code for zoo.tfpark.gan.gan_estimator

#
# Copyright 2018 Analytics Zoo Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import inspect
import os
import tempfile

import tensorflow as tf

from zoo.tfpark import TFOptimizer
# TODO: make GANEstimator inherit from Estimator
from zoo.tfpark.zoo_optimizer import FakeOptimMethod
from zoo.util import nest


class GANEstimator(object):
    """Train a generator/discriminator pair by alternating their updates.

    A single global step counter decides which network is updated at each
    step: within every period of ``discriminator_steps + generator_steps``
    steps, the first ``discriminator_steps`` steps update the discriminator
    and the remaining ones update the generator.

    Any of the four user supplied callables may optionally declare a
    ``counter`` parameter; when present, the global step tensor is passed to
    it as the keyword argument ``counter``.
    """

    def __init__(self,
                 generator_fn,
                 discriminator_fn,
                 generator_loss_fn,
                 discriminator_loss_fn,
                 generator_optimizer,
                 discriminator_optimizer,
                 generator_steps=1,
                 discriminator_steps=1,
                 model_dir=None,
                 session_config=None,
                 ):
        """Store the model functions, optimizers and checkpoint location.

        :param generator_fn: callable invoked as ``generator_fn(generator_inputs)``
               that returns the generated data.
        :param discriminator_fn: callable invoked as
               ``discriminator_fn(data, generator_inputs)``, once on generated
               data and once (with reused variables) on real data.
        :param generator_loss_fn: callable invoked as
               ``generator_loss_fn(fake_d_outputs)``.
        :param discriminator_loss_fn: callable invoked as
               ``discriminator_loss_fn(real_d_outputs, fake_d_outputs)``.
        :param generator_optimizer: a ``ZooOptimizer`` used for the generator.
        :param discriminator_optimizer: a ``ZooOptimizer`` used for the
               discriminator.
        :param generator_steps: number of generator updates per period.
        :param discriminator_steps: number of discriminator updates per period.
        :param model_dir: directory used for checkpoints; a fresh temporary
               directory is created when it is ``None``.
        :param session_config: forwarded unchanged to ``TFOptimizer`` during
               :meth:`train`.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle with zoo.tfpark -- TODO confirm).
        from zoo.tfpark import ZooOptimizer
        assert isinstance(generator_optimizer, ZooOptimizer),\
            "generator_optimizer should be a ZooOptimizer"
        assert isinstance(discriminator_optimizer, ZooOptimizer),\
            "discriminator_optimizer should be a ZooOptimizer"
        self._generator_fn = generator_fn
        self._discriminator_fn = discriminator_fn
        self._generator_loss_fn = generator_loss_fn
        self._discriminator_loss_fn = discriminator_loss_fn
        self._generator_steps = generator_steps
        self._discriminator_steps = discriminator_steps
        self._gen_opt = generator_optimizer
        self._dis_opt = discriminator_optimizer
        self._session_config = session_config

        if model_dir is None:
            folder = tempfile.mkdtemp()
            self.checkpoint_path = os.path.join(folder, "model")
            self.model_dir = folder
        else:
            self.checkpoint_path = os.path.join(model_dir, "model")
            self.model_dir = model_dir

    @staticmethod
    def _call_fn_maybe_with_counter(fn, counter, *args):
        """Call ``fn(*args)``, adding ``counter=counter`` if ``fn`` accepts it.

        :param fn: the user supplied callable.
        :param counter: value passed as the ``counter`` keyword argument when
               ``fn`` declares a parameter with that name.
        :param args: positional arguments forwarded to ``fn``.
        :return: whatever ``fn`` returns.
        """
        # inspect.getargspec is deprecated, fails on functions that use
        # keyword-only arguments or annotations, and no longer exists on
        # Python 3.11+. Prefer getfullargspec and only fall back to
        # getargspec where the former is unavailable (Python 2).
        if hasattr(inspect, "getfullargspec"):
            spec = inspect.getfullargspec(fn)
            fn_args = list(spec.args) + list(spec.kwonlyargs)
        else:
            fn_args = inspect.getargspec(fn).args

        if "counter" in fn_args:
            return fn(*args, counter=counter)
        else:
            return fn(*args)

    def train(self, input_fn, end_trigger):
        """Build the GAN training graph, run it and save a checkpoint.

        :param input_fn: zero-argument callable returning the dataset. The
               returned object must expose ``tensors`` (index 0 is fed to the
               generator, index 1 is the real data) and ``_original_tensors``.
        :param end_trigger: passed to ``TFOptimizer.optimize`` to decide when
               training stops.
        """
        with tf.Graph().as_default():

            dataset = input_fn()

            generator_inputs = dataset.tensors[0]
            real_data = dataset.tensors[1]

            counter = tf.train.get_or_create_global_step()

            # The first `discriminator_steps` steps of every period train the
            # discriminator, the rest of the period trains the generator.
            period = self._discriminator_steps + self._generator_steps

            is_discriminator_phase = tf.less(tf.mod(counter, period),
                                             self._discriminator_steps)

            with tf.variable_scope("Generator"):
                gen_data = self._call_fn_maybe_with_counter(self._generator_fn, counter,
                                                            generator_inputs)

            with tf.variable_scope("Discriminator"):
                fake_d_outputs = self._call_fn_maybe_with_counter(self._discriminator_fn,
                                                                  counter,
                                                                  gen_data, generator_inputs)

            # Same discriminator variables, evaluated on the real data.
            with tf.variable_scope("Discriminator", reuse=True):
                real_d_outputs = self._call_fn_maybe_with_counter(self._discriminator_fn,
                                                                  counter,
                                                                  real_data, generator_inputs)

            with tf.name_scope("Generator_loss"):
                generator_loss = self._call_fn_maybe_with_counter(self._generator_loss_fn,
                                                                  counter,
                                                                  fake_d_outputs)
                gen_reg_loss = tf.losses.get_regularization_loss("Generator")
                generator_loss = generator_loss + gen_reg_loss

            with tf.name_scope("Discriminator_loss"):
                discriminator_loss = self._call_fn_maybe_with_counter(
                    self._discriminator_loss_fn,
                    counter,
                    real_d_outputs,
                    fake_d_outputs)
                dis_reg_loss = tf.losses.get_regularization_loss("Discriminator")
                discriminator_loss = discriminator_loss + dis_reg_loss

            generator_variables = tf.trainable_variables("Generator")
            discriminator_variables = tf.trainable_variables("Discriminator")

            # Both tf.cond branches must return the same structure, so each
            # branch pads the gradients of the network it does not train with
            # zeros, always ordered as generator followed by discriminator.
            def run_gen_compute():
                gen_grads_vars = self._gen_opt.compute_gradients(generator_loss,
                                                                 var_list=generator_variables)
                gen_grads = [grad for grad, var in gen_grads_vars]
                dis_grads = [tf.zeros_like(var) for var in discriminator_variables]

                return gen_grads + dis_grads

            def run_dis_compute():
                # Use the discriminator's own optimizer here; the original
                # code mistakenly computed these gradients with the
                # generator optimizer.
                dis_grads_vars = self._dis_opt.compute_gradients(discriminator_loss,
                                                                 var_list=discriminator_variables)
                dis_grads = [grad for grad, var in dis_grads_vars]
                gen_grads = [tf.zeros_like(var) for var in generator_variables]
                return gen_grads + dis_grads

            grads = tf.cond(is_discriminator_phase, run_dis_compute, run_gen_compute)

            grads_vars = list(zip(grads, generator_variables + discriminator_variables))

            gen_grads_vars = grads_vars[:len(generator_variables)]
            dis_grads_vars = grads_vars[len(generator_variables):]

            grads = [grad for grad, var in grads_vars]

            # Only the network selected by the current phase is updated.
            _train_op = tf.cond(is_discriminator_phase,
                                lambda: self._dis_opt.apply_gradients(dis_grads_vars),
                                lambda: self._gen_opt.apply_gradients(gen_grads_vars))

            variables = generator_variables + discriminator_variables

            loss = tf.cond(is_discriminator_phase,
                           lambda: discriminator_loss,
                           lambda: generator_loss)

            # Chain: apply gradients -> increment the step counter -> train_op.
            with tf.control_dependencies([_train_op]):
                increase_counter = tf.assign_add(counter, 1)

            with tf.control_dependencies([increase_counter]):
                train_op = tf.no_op()

            with tf.Session() as sess:
                sess.run(tf.global_variables_initializer())
                saver = tf.train.Saver()
                # Resume from the latest checkpoint in model_dir, if any.
                kpt = tf.train.latest_checkpoint(self.model_dir)
                if kpt is not None:
                    saver.restore(sess, kpt)
                opt = TFOptimizer._from_grads(loss, sess,
                                              inputs=nest.flatten(dataset._original_tensors),
                                              labels=[],
                                              grads=grads, variables=variables, dataset=dataset,
                                              optim_method=FakeOptimMethod(),
                                              session_config=self._session_config,
                                              model_dir=os.path.join(self.model_dir, "tmp"),
                                              train_op=train_op)
                opt.optimize(end_trigger)
                saver = tf.train.Saver()
                saver.save(sess, self.checkpoint_path, global_step=counter)