
Generative Adversarial Networks (GANs)

Posted: 2023-03-05 15:14:08



Generative Adversarial Nets (GANs), introduced by Goodfellow et al. in 2014, train two networks against each other: a generator G that maps a random noise vector z to a fake sample, and a discriminator D that estimates the probability that a sample is real rather than generated. The script below implements a small fully connected GAN on MNIST using the TensorFlow 1.x API.
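For reference, the two losses computed in build_model correspond to the standard GAN value function,

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}\left[\log D(x)\right] + \mathbb{E}_{z \sim p_z}\left[\log\left(1 - D(G(z))\right)\right]

where d_cost is the negated batch estimate of V(D, G), and g_cost = -\mathbb{E}_{z}[\log D(G(z))] is the usual non-saturating generator loss used in place of minimizing \log(1 - D(G(z))) directly.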

# A minimal fully connected GAN trained on MNIST (TensorFlow 1.x API).
import os
from datetime import datetime

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tensorflow.examples.tutorials.mnist import input_data

BATCH_SIZE = 128
LEARNING_RATE = 1e-4
Z_DIM = 100          # dimension of the noise vector z
IMAGE_W = 28
IMAGE_H = 28
model_dir = 'model_gan'

# placeholder for real MNIST images, flattened to 784 pixels
x_in = tf.placeholder(tf.float32, shape=[None, 784])

def load_mnist():
    return input_data.read_data_sets("./MNIST_data", one_hot=True)

mnist = load_mnist()

def get_W_b(input_dim, output_dim, name):
    # 'gen_W_b_1' yields a weight named 'gen_W_1' and a bias named 'gen_b_1',
    # so both can later be selected by the 'gen' / 'discrim' name prefix
    W = tf.Variable(tf.random_normal([input_dim, output_dim], stddev=0.02), name=name.replace('_b', ''))
    b = tf.Variable(tf.zeros([output_dim], tf.float32), name=name.replace('_W', ''))
    return W, b

tmp = 256  # hidden layer width of both networks

class GAN(object):
    def __init__(self, lr=LEARNING_RATE, batch_size=BATCH_SIZE, z_dim=Z_DIM):
        self.lr = lr
        self.batch_size = batch_size
        self.z_dim = z_dim
        # generator weights
        self.gen_W1, self.gen_b1 = get_W_b(z_dim, tmp, 'gen_W_b_1')
        self.gen_W2, self.gen_b2 = get_W_b(tmp, IMAGE_H * IMAGE_W, 'gen_W_b_2')
        # discriminator weights
        self.discrim_W1, self.discrim_b1 = get_W_b(IMAGE_H * IMAGE_W, tmp, 'discrim_W_b_1')
        self.discrim_W2, self.discrim_b2 = get_W_b(tmp, 1, 'discrim_W_b_2')

    # discriminator: image -> probability that the image is real
    def discriminator(self, x):
        d_h1 = tf.nn.relu(tf.add(tf.matmul(x, self.discrim_W1), self.discrim_b1))
        d_h2 = tf.add(tf.matmul(d_h1, self.discrim_W2), self.discrim_b2)
        return tf.nn.sigmoid(d_h2)

    # generator: noise z -> image with pixel values in (0, 1)
    def generator(self, z):
        g_h1 = tf.nn.relu(tf.add(tf.matmul(z, self.gen_W1), self.gen_b1))
        g_h2 = tf.add(tf.matmul(g_h1, self.gen_W2), self.gen_b2)
        return tf.nn.sigmoid(g_h2)

    # build the model losses
    def build_model(self):
        # noise is drawn inside the graph, so every sess.run sees a fresh z batch
        z_sample = tf.random_uniform([self.batch_size, self.z_dim], minval=-1., maxval=1.)
        g_image = self.generator(z_sample)
        d_real = self.discriminator(x_in)
        d_fake = self.discriminator(g_image)
        d_cost = -tf.reduce_mean(tf.log(d_real) + tf.log(1. - d_fake))
        g_cost = -tf.reduce_mean(tf.log(d_fake))
        return d_cost, g_cost, tf.reduce_mean(d_real), tf.reduce_mean(d_fake)

# plot 16 samples as a 4x4 grid
def plot_grid(samples):
    fig = plt.figure(figsize=(4, 4))
    gs = gridspec.GridSpec(4, 4)
    gs.update(wspace=0.05, hspace=0.05)
    for i, sample in enumerate(samples):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(sample.reshape(IMAGE_H, IMAGE_W), cmap='Greys_r')
    return fig

# training loop
def train():
    with tf.Session() as sess:
        gan = GAN()
        discrim_vars = list(filter(lambda x: x.name.startswith('discrim'), tf.trainable_variables()))
        gen_vars = list(filter(lambda x: x.name.startswith('gen'), tf.trainable_variables()))
        d_cost, g_cost, d_real, d_fake = gan.build_model()
        optimizer = tf.train.AdamOptimizer(learning_rate=LEARNING_RATE)
        d_opt = optimizer.minimize(d_cost, var_list=discrim_vars)
        g_opt = optimizer.minimize(g_cost, var_list=gen_vars)
        saver = tf.train.Saver()
        os.makedirs(model_dir, exist_ok=True)
        checkpoint = tf.train.latest_checkpoint(model_dir)
        if checkpoint:
            saver.restore(sess, checkpoint)  # resume from the saved model
            print('checkpoint: {}'.format(checkpoint))
        else:
            # initialize variables
            sess.run(tf.global_variables_initializer())
        print("Started training {}".format(datetime.now().isoformat()[11:]))
        plot_index = 0
        for step in range(100000):
            batch_x, _ = mnist.train.next_batch(BATCH_SIZE)
            sess.run(d_opt, feed_dict={x_in: batch_x})
            sess.run(g_opt, feed_dict={x_in: batch_x})
            # every 1000 steps: log the losses, save a grid of generated images and a checkpoint
            if step % 1000 == 0:
                batch_x, _ = mnist.train.next_batch(BATCH_SIZE)
                d_cost_, d_real_, d_fake_ = sess.run([d_cost, d_real, d_fake], feed_dict={x_in: batch_x})
                g_cost_ = sess.run(g_cost, feed_dict={x_in: batch_x})
                print("step:{} Discriminator loss {} Generator loss {} d_real:{} d_fake:{}".format(
                    step, d_cost_, g_cost_, d_real_, d_fake_))
                z_sample = np.random.uniform(-1., 1., size=[16, Z_DIM]).astype('float32')
                g_image = sess.run(gan.generator(z_sample))
                fig = plot_grid(g_image)
                plt.savefig(r'D:\project\生成对抗网络\img\{}.png'.format(str(plot_index).zfill(4)), bbox_inches='tight')
                plot_index += 1
                plt.close(fig)
                # save the model
                saver.save(sess, "{}/model_gan.model".format(model_dir), global_step=step)
        print("Ended training {}".format(datetime.now().isoformat()[11:]))

if __name__ == "__main__":
    train()
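The script above only saves checkpoints during training. As a minimal sketch (not part of the original), the snippet below shows one way the latest checkpoint could be restored in a fresh process to draw new samples, assuming the definitions above are in scope; the function name sample_from_checkpoint and the output file samples.png are illustrative choices.

def sample_from_checkpoint(n_samples=16):
    # rebuild the generator/discriminator variables with the same names as during training
    gan = GAN()
    z = tf.placeholder(tf.float32, shape=[None, Z_DIM])
    g_image = gan.generator(z)
    saver = tf.train.Saver()  # maps the rebuilt variables to their checkpoint entries
    with tf.Session() as sess:
        # assumes at least one checkpoint has already been written to model_dir
        checkpoint = tf.train.latest_checkpoint(model_dir)
        saver.restore(sess, checkpoint)
        z_sample = np.random.uniform(-1., 1., size=[n_samples, Z_DIM]).astype('float32')
        images = sess.run(g_image, feed_dict={z: z_sample})
    fig = plot_grid(images)
    plt.savefig('samples.png', bbox_inches='tight')  # illustrative output path
    plt.close(fig)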
