VQ-GAN Reproduction


I've been studying autoencoders recently, so here is my reproduction code. The engineering scaffolding has been stripped out, keeping only the core. If you have multiple GPUs, set accelerate=True; otherwise leave it off.

The Encoder and Decoder follow the Stable Diffusion design. The encoder's final layer outputs a mean and a log-variance (i.e., a true VAE); the latent feature map is produced by sampling from that distribution and is then quantized with VQ. The reconstructed images are still somewhat blurry; I suspect some training images had been compressed, interpolated up to 256×256, or saved as lossy JPEG, so the model learned that noise.
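
The flow in one minimal, self-contained sketch (the shapes and the stand-in codebook here are illustrative only, not the actual classes below):

import torch

# sample a latent with the reparameterization trick, then snap each spatial
# position to its nearest codebook vector (straight-through estimator)
moments = torch.randn(1, 8, 16, 16)                  # stand-in encoder output (2 * z_channels)
mean, logvar = torch.chunk(moments, 2, dim=1)        # -> two [1, 4, 16, 16] tensors
z = mean + torch.exp(0.5 * logvar) * torch.randn_like(mean)

codebook = torch.randn(512, 4)                       # stand-in codebook: 512 entries of dim 4
flat = z.permute(0, 2, 3, 1).reshape(-1, 4)          # [256, 4] latent vectors
idx = torch.cdist(flat, codebook).argmin(dim=1)      # nearest codebook entry per position
z_q = codebook[idx].view(1, 16, 16, 4).permute(0, 3, 1, 2)
z_q = z + (z_q - z).detach()                         # gradients flow straight through to z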

Data source:

Konachan anime avatar dataset (Konachan动漫头像数据集), hosted on 飞桨AI Studio星河社区 (PaddlePaddle AI Studio)

Results

epoch 0 step 100

epoch 6 step 10000

epoch 50 step 85000

epoch 100 step 176700

Model code

import math

import numpy as np
import torch
from torch import nn
from torch.nn import functional as F


class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1):
        super(ConvBlock, self).__init__()
        self.conv_block = nn.Sequential(
            nn.GroupNorm(groups, in_channels),
            nn.SiLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding),
        )

    def forward(self, x):
        return self.conv_block(x)


class ResnetBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=32):
        super(ResnetBlock, self).__init__()
        self.conv_block = nn.Sequential(
            ConvBlock(in_channels, out_channels, groups=groups),
            ConvBlock(out_channels, out_channels, groups=groups),
        )
        if in_channels != out_channels:
            self.skip_conn = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        else:
            self.skip_conn = nn.Identity()

    def forward(self, x):
        return self.conv_block(x) + self.skip_conn(x)


class AttentionBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=32):
        super(AttentionBlock, self).__init__()
        self.q_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        self.k_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        self.v_conv = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        self.out_conv = ConvBlock(out_channels, out_channels, kernel_size=1, padding=0, groups=groups)

        if in_channels != out_channels:
            self.skip_conn = ConvBlock(in_channels, out_channels, kernel_size=1, padding=0, groups=groups)
        else:
            self.skip_conn = nn.Identity()

    def forward(self, x):
        b, c, h, w = x.shape
        # project to queries/keys/values and flatten the spatial dimensions
        q = self.q_conv(x).reshape(b, -1, h * w)
        k = self.k_conv(x).reshape(b, -1, h * w)
        v = self.v_conv(x).reshape(b, -1, h * w)

        # scaled dot-product attention over all spatial positions,
        # scaled by sqrt(channel dim) rather than the feature-map width
        attention = torch.einsum('bcq,bck->bqk', q, k) / math.sqrt(q.shape[1])
        attention = attention.softmax(dim=-1)

        out = torch.einsum('bqk,bck->bcq', attention, v).reshape(b, -1, h, w)
        out = self.out_conv(out)

        return out + self.skip_conn(x)


class MiddleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, groups=32):
        super(MiddleBlock, self).__init__()
        self.conv_block = nn.Sequential(
            ResnetBlock(in_channels, out_channels, groups=groups),
            AttentionBlock(out_channels, out_channels, groups=groups),
            ResnetBlock(out_channels, out_channels, groups=groups),
        )

    def forward(self, x):
        return self.conv_block(x)


class UpSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 1, 1)

    def forward(self, x):
        x = F.interpolate(x, scale_factor=2, mode='nearest')
        x = self.conv(x)
        return x


class DownSample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownSample, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, 3, 2, 0)

    def forward(self, x):
        # asymmetric (right/bottom) padding so the stride-2 conv halves the resolution exactly
        pad = (0, 1, 0, 1)
        x = F.pad(x, pad, mode='constant', value=0)
        x = self.conv(x)
        return x


class DownBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(DownBlock, self).__init__()
        self.down_block = nn.Sequential(
            ResnetBlock(in_channels, out_channels),
            ResnetBlock(out_channels, out_channels),
        )

    def forward(self, x):
        return self.down_block(x)


class UpBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UpBlock, self).__init__()
        self.up_block = nn.Sequential(
            ResnetBlock(in_channels, out_channels),
            ResnetBlock(out_channels, out_channels),
        )

    def forward(self, x):
        return self.up_block(x)


class Encoder(nn.Module):
    def __init__(self, in_channels, z_channels, groups=32):
        super(Encoder, self).__init__()
        self.conv = nn.Conv2d(in_channels, 128, 3, 1, 1)
        self.res_block = self.create_resnet_block(128, 128, 2, groups=groups)
        self.res_block2 = self.create_resnet_block(128, 256, 2, groups=groups)
        self.res_block3 = self.create_resnet_block(256, 512, 2, groups=groups)
        self.down_block = DownBlock(512, 512)
        self.middle_block = MiddleBlock(512, 512, groups=groups)
        self.conv_block = ConvBlock(512, z_channels * 2, groups=groups)

    @staticmethod
    def create_resnet_block(in_channels, out_channels, num_blocks, groups=32):
        res_blocks = []
        for _ in range(num_blocks):
            res_blocks.append(ResnetBlock(in_channels, in_channels, groups=groups))
        res_blocks.append(DownSample(in_channels, out_channels))
        return nn.Sequential(*res_blocks)

    def forward(self, x):
        x = self.conv(x)
        x = self.res_block(x)
        x = self.res_block2(x)
        x = self.res_block3(x)
        x = self.down_block(x)
        x = self.middle_block(x)
        x = self.conv_block(x)
        return x


class Decoder(nn.Module):
    def __init__(self, in_channels, groups=32):
        super(Decoder, self).__init__()
        self.conv = nn.Conv2d(in_channels, 512, 3, 1, 1)
        self.middle_block = MiddleBlock(512, 512, groups=groups)
        self.resnet_block = self.create_resnet_block(512, 512, 3, groups=groups)
        self.resnet_block2 = self.create_resnet_block(512, 256, 3, groups=groups)
        self.resnet_block3 = self.create_resnet_block(256, 128, 3, groups=groups)
        self.up_block = UpBlock(128, 128)
        self.conv_block = ConvBlock(128, 3, groups=groups)

    @staticmethod
    def create_resnet_block(in_channels, out_channels, num_blocks, groups=32):
        res_blocks = []
        for _ in range(num_blocks):
            res_blocks.append(ResnetBlock(in_channels, in_channels, groups=groups))
        res_blocks.append(UpSample(in_channels, out_channels))
        return nn.Sequential(*res_blocks)

    def forward(self, x):
        x = self.conv(x)
        x = self.middle_block(x)
        x = self.resnet_block(x)
        x = self.resnet_block2(x)
        x = self.resnet_block3(x)
        x = self.up_block(x)
        x = self.conv_block(x)
        return x


# diagonal Gaussian posterior over the latent, following the Stable Diffusion VAE
class DiagonalGaussianDistribution(object):
    def __init__(self, parameters, deterministic=False):
        self.parameters = parameters
        self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)

    def sample(self):
        x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
        return x

    def kl(self, other=None):
        if self.deterministic:
            return torch.Tensor([0.])
        else:
            if other is None:
                return 0.5 * torch.sum(torch.pow(self.mean, 2)
                                       + self.var - 1.0 - self.logvar,
                                       dim=[1, 2, 3])
            else:
                return 0.5 * torch.sum(
                    torch.pow(self.mean - other.mean, 2) / other.var
                    + self.var / other.var - 1.0 - self.logvar + other.logvar,
                    dim=[1, 2, 3])

    def nll(self, sample, dims=None):
        if dims is None:
            dims = [1, 2, 3]
        if self.deterministic:
            return torch.Tensor([0.])
        log_two_pi = np.log(2.0 * np.pi)
        return 0.5 * torch.sum(
            log_two_pi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
            dim=dims)

    def mode(self):
        return self.mean


class VectorQuantizer(nn.Module):
    """Vector quantization layer with optional EMA codebook updates."""

    def __init__(self, num_embeddings, embedding_dim, beta=0.25, decay=0.99, epsilon=1e-5, ema=False):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.num_embeddings = num_embeddings
        self.beta = beta
        self.decay = decay
        self.epsilon = epsilon
        self.ema = ema

        # codebook initialization
        self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
        self.embedding.weight.data.normal_()
        # self.embedding.requires_grad_(False)
        # EMA running statistics
        self.register_buffer('_ema_cluster_size', torch.zeros(num_embeddings))
        self.register_buffer('_ema_w', self.embedding.weight.data.clone())

    def forward(self, z):
        # reshape [B, D, H, W] -> [B, H, W, D] and flatten the spatial positions
        z = z.permute(0, 2, 3, 1)
        z_flattened = z.reshape(-1, self.embedding_dim)

        # squared L2 distance to every codebook entry
        distances = torch.cdist(z_flattened, self.embedding.weight, p=2.0) ** 2

        # nearest-neighbor codebook lookup
        encoding_indices = torch.argmin(distances, dim=1)
        quantized = self.embedding(encoding_indices).view(z.shape)
        quantized = quantized.permute(0, 3, 1, 2)

        # commitment loss: pull the encoder output toward its assigned code
        vq_loss = self.beta * F.mse_loss(quantized.detach(), z.permute(0, 3, 1, 2))

        if self.training and self.ema:
            # EMA codebook update
            with torch.no_grad():
                # one-hot assignments and EMA cluster-size update
                encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=z.device)
                encodings.scatter_(1, encoding_indices.view(-1, 1), 1)
                updated_ema_cluster_size = (self._ema_cluster_size * self.decay
                                            + (1 - self.decay) * torch.sum(encodings, 0))

                # Laplace smoothing keeps every cluster size strictly positive
                n = torch.sum(updated_ema_cluster_size)
                updated_ema_cluster_size = ((updated_ema_cluster_size + self.epsilon)
                                            / (n + self.num_embeddings * self.epsilon) * n)

                dw = torch.matmul(encodings.t(), z_flattened)
                updated_ema_w = self._ema_w * self.decay + (1 - self.decay) * dw

                # write back the EMA statistics and the updated codebook
                self._ema_cluster_size.data.copy_(updated_ema_cluster_size)
                self._ema_w.data.copy_(updated_ema_w)
                self.embedding.weight.data.copy_(updated_ema_w / updated_ema_cluster_size.unsqueeze(1))
        else:
            # without EMA, learn the codebook through an explicit codebook loss
            codebook_loss = F.mse_loss(quantized, z.permute(0, 3, 1, 2).detach())
            vq_loss = vq_loss + codebook_loss

        # straight-through estimator: copy gradients from quantized to z
        quantized = z.permute(0, 3, 1, 2) + (quantized - z.permute(0, 3, 1, 2)).detach()

        return quantized, encoding_indices, vq_loss


class VAE(nn.Module):
    def __init__(self, in_channels, groups=32, z_channels=4, embedding_dim=4):
        super(VAE, self).__init__()
        self.scale_factor = 0.18215  # latent scaling constant, as in Stable Diffusion
        self.encoder = Encoder(in_channels, z_channels, groups=groups)
        self.decoder = Decoder(z_channels, groups=groups)
        self.quant_conv = nn.Conv2d(z_channels * 2, embedding_dim * 2, 1, 1, 0)
        self.post_quant_conv = nn.Conv2d(embedding_dim, z_channels, 1, 1, 0)

    def encode(self, x):
        h = self.encoder(x)
        moments = self.quant_conv(h)
        posterior = DiagonalGaussianDistribution(moments)
        out = posterior.sample()
        out = self.scale_factor * out
        return out

    def decode(self, z):
        z = 1. / self.scale_factor * z
        z = self.post_quant_conv(z)
        dec = self.decoder(z)
        return dec

    def forward(self, x):
        z = self.encode(x)
        dec = self.decode(z)
        return dec

    def generate(self, x):
        x = self.decoder(x)
        return x


class VQVAE(VAE):
    def __init__(self, in_channels=3, groups=8, z_channels=4, embedding_dim=4, num_embeddings=8196, beta=0.25,
                 decay=0.99, epsilon=1e-5):
        super(VQVAE, self).__init__(in_channels, groups, z_channels, embedding_dim)
        self.quantize = VectorQuantizer(num_embeddings,
                                        embedding_dim,
                                        ema=True,
                                        beta=beta,
                                        decay=decay,
                                        epsilon=epsilon)

    def forward(self, x):
        z = self.encode(x)
        quantized, _, vq_loss = self.quantize(z)
        x_recon = self.decode(quantized)
        return x_recon, vq_loss

    def calculate_balance_factor(self, perceptual_loss, gan_loss):
        # adaptive GAN weight from the VQ-GAN paper: the ratio of gradient norms of
        # the reconstruction/perceptual loss and the GAN loss at the decoder's last layer
        last_layer = self.decoder.conv_block.conv_block[-1]
        last_layer_weight = last_layer.weight
        perceptual_loss_grads = torch.autograd.grad(perceptual_loss, last_layer_weight, retain_graph=True)[0]
        gan_loss_grads = torch.autograd.grad(gan_loss, last_layer_weight, retain_graph=True)[0]

        alpha = torch.norm(perceptual_loss_grads) / (torch.norm(gan_loss_grads) + 1e-4)
        alpha = torch.clamp(alpha, 0, 1e4).detach()
        return 0.8 * alpha
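
A quick shape sanity check for the model code above (a minimal sketch; three stride-2 downsamples map a 256×256 input to a 32×32 latent before quantization):

if __name__ == "__main__":
    # smoke test: reconstruct random noise and check shapes
    vqvae = VQVAE(in_channels=3, groups=8, z_channels=4, embedding_dim=4, num_embeddings=8196)
    dummy = torch.randn(1, 3, 256, 256)
    recon, vq_loss = vqvae(dummy)
    print(recon.shape, vq_loss.item())  # torch.Size([1, 3, 256, 256]) and a scalar VQ loss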

Training script

import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from accelerate import Accelerator, DistributedDataParallelKwargs
from lpips import LPIPS
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import VGG19_Weights
from tqdm import tqdm

from vae import VQVAE


# --------------------------
# Adversarial components
# --------------------------

class Discriminator(nn.Module):
    """PatchGAN-style convolutional discriminator."""

    def __init__(self, in_channels=3, base_channels=4, num_layers=3):
        super().__init__()
        layers = [nn.Conv2d(in_channels, base_channels, 4, 2, 1), nn.LeakyReLU(0.2)]
        channels = base_channels
        for _ in range(1, num_layers):
            layers += [
                nn.Conv2d(channels, channels * 2, 4, 2, 1),
                nn.InstanceNorm2d(channels * 2),
                nn.LeakyReLU(0.2)
            ]
            channels *= 2
        layers += [
            nn.Conv2d(channels, channels, 4, 1, 0),
            nn.InstanceNorm2d(channels),
            nn.LeakyReLU(0.2),
            nn.Conv2d(channels, 1, 1)
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class PerceptualLoss(nn.Module):
    def __init__(self, layers=None):
        super(PerceptualLoss, self).__init__()
        if layers is None:
            layers = ['1', '2', '4', '7']
        self.layers = layers
        self.vgg = torchvision.models.vgg19(weights=VGG19_Weights.DEFAULT).features.eval()
        self.vgg.requires_grad_(False)
        for name, module in self.vgg.named_modules():
            if name in layers:
                module.register_forward_hook(self.forward_hook)
        self.features = []

    def forward_hook(self, module, input, output):
        self.features.append(output)

    def forward(self, x, x_recon):
        # run real and reconstructed images through VGG in one batch
        x_and_x_recon = torch.cat((x, x_recon), dim=0)
        self.features = []
        self.vgg(x_and_x_recon)
        x_and_x_recon_features = self.features

        loss = torch.tensor(0.0, device=x.device)

        # channel-normalized L1 distance between feature maps (LPIPS-style)
        for i, layer in enumerate(self.layers):
            x_feature = x_and_x_recon_features[i][:x.shape[0]]
            x_norm_factor = torch.sqrt(torch.mean(x_feature ** 2, dim=1, keepdim=True))
            x_feature = x_feature / (x_norm_factor + 1e-8)
            x_recon_feature = x_and_x_recon_features[i][x.shape[0]:]
            x_recon_norm_factor = torch.sqrt(torch.mean(x_recon_feature ** 2, dim=1, keepdim=True))
            x_recon_feature = x_recon_feature / (x_recon_norm_factor + 1e-8)
            loss += F.l1_loss(x_feature, x_recon_feature)

        return loss


# --------------------------
# Training loop
# --------------------------

def train_vqgan(dataloader, epochs=100, mixed_precision=False, accelerate=False, disc_start=10000, rec_factor=1,
                perceptual_factor=1, learning_rate=4.5e-6, in_channels=3, groups=8, z_channels=4, embedding_dim=4,
                num_embeddings=8196, beta=0.25, decay=0.99, epsilon=1e-5):
    os.makedirs('results', exist_ok=True)
    # initialize models
    model = VQVAE(in_channels, groups, z_channels, embedding_dim, num_embeddings, beta, decay, epsilon)
    discriminator = Discriminator()
    # perceptual_loss_fn = PerceptualLoss()
    perceptual_loss_fn = LPIPS().eval()  # note: LPIPS expects inputs roughly in [-1, 1]
    # optimizers
    opt_ae = Adam(list(model.encoder.parameters()) + list(model.decoder.parameters())
                  + list(model.quantize.parameters()), lr=learning_rate, betas=(0.5, 0.9))
    opt_disc = Adam(discriminator.parameters(), lr=learning_rate, betas=(0.5, 0.9))
    gradient_accumulation_steps = 4
    step = 0
    start_epoch = 0
    if os.path.exists("vqgan.pth"):
        state_dict = torch.load("vqgan.pth")
        step = state_dict.get("step", 0)
        start_epoch = state_dict.get("epoch", 0)
        model.load_state_dict(state_dict.get("model", {}))
        discriminator.load_state_dict(state_dict.get("discriminator", {}))
        opt_ae.load_state_dict(state_dict.get("opt_ae", {}))
        opt_disc.load_state_dict(state_dict.get("opt_disc", {}))
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    if accelerate:
        accelerator = Accelerator(gradient_accumulation_steps=gradient_accumulation_steps,
                                  mixed_precision='fp16' if mixed_precision else 'no',
                                  kwargs_handlers=[ddp_kwargs])
        # prepare models, optimizers and dataloader for (multi-GPU) training
        model, discriminator, perceptual_loss_fn, opt_ae, opt_disc, dataloader = accelerator.prepare(
            model, discriminator, perceptual_loss_fn, opt_ae, opt_disc, dataloader)
        device = accelerator.device
    else:
        accelerator = None
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = model.to(device)
        discriminator = discriminator.to(device)
        perceptual_loss_fn = perceptual_loss_fn.to(device)

    for epoch in range(start_epoch, epochs):
        with tqdm(range(len(dataloader))) as pbar:
            for _, batch in zip(pbar, dataloader):
                x, _ = batch
                x = x.to(device)

                if accelerator is not None:
                    # generator + discriminator updates
                    with accelerator.autocast():
                        disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon = train_step(
                            accelerator,
                            disc_start,
                            discriminator,
                            model,
                            perceptual_factor,
                            perceptual_loss_fn,
                            rec_factor,
                            step,
                            x)
                        opt_ae.zero_grad()
                        accelerator.backward(total_loss)
                        opt_disc.zero_grad()
                        accelerator.backward(disc_loss)
                        opt_ae.step()
                        opt_disc.step()
                else:
                    # generator + discriminator updates
                    with torch.amp.autocast(device, enabled=mixed_precision):
                        disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon = train_step(
                            accelerator,
                            disc_start,
                            discriminator,
                            model,
                            perceptual_factor,
                            perceptual_loss_fn,
                            rec_factor,
                            step,
                            x)
                        opt_ae.zero_grad()
                        total_loss.backward()
                        opt_disc.zero_grad()
                        disc_loss.backward()
                        opt_ae.step()
                        opt_disc.step()

                pbar.set_postfix(
                    TotalLoss=round(total_loss.item(), 5),
                    DiscLoss=round(disc_loss.item(), 3),
                    PerceptualLoss=round(perceptual_loss.item(), 5),
                    RecLoss=round(rec_loss.item(), 5),
                    GenLoss=round(g_loss.item(), 5),
                    VqLoss=round(vq_loss.item(), 5)
                )
                pbar.update(0)  # refresh the bar so the new postfix is shown
                # logging: save a real/reconstructed image grid every 100 steps
                if step % 100 == 0:
                    if accelerator:
                        if accelerator.is_main_process:
                            with torch.no_grad():
                                fake_image = x_recon[:4].permute(0, 2, 3, 1).contiguous()
                                means = torch.tensor([0.485, 0.456, 0.406]).view(1, 1, 1, 3).to(fake_image.device)
                                stds = torch.tensor([0.229, 0.224, 0.225]).view(1, 1, 1, 3).to(fake_image.device)
                                fake_image = fake_image * stds + means
                                fake_image.clamp_(0, 1)
                                fake_image = fake_image.permute(0, 3, 1, 2).contiguous()
                                real_image = x[:4].permute(0, 2, 3, 1).contiguous()
                                real_image = real_image * stds + means
                                real_image.clamp_(0, 1)
                                real_image = real_image.permute(0, 3, 1, 2).contiguous()

                                real_fake_images = torch.cat((real_image, fake_image))
                                torchvision.utils.save_image(real_fake_images,
                                                             os.path.join("results", f"{epoch}_{step}.jpg"),
                                                             nrow=4)
                    else:
                        with torch.no_grad():
                            # undo the same ImageNet normalization as the accelerate branch
                            means = torch.tensor([0.485, 0.456, 0.406], device=x.device).view(1, 3, 1, 1)
                            stds = torch.tensor([0.229, 0.224, 0.225], device=x.device).view(1, 3, 1, 1)
                            real_fake_images = torch.cat((x[:4], x_recon[:4])) * stds + means
                            real_fake_images = real_fake_images.clamp(0, 1)
                            torchvision.utils.save_image(real_fake_images,
                                                         os.path.join("results", f"{epoch}_{step}.jpg"),
                                                         nrow=4)
                step += 1
            if accelerate:
                if accelerator.is_main_process:
                    unwrapped_model = accelerator.unwrap_model(model)
                    unwrapped_discriminator = accelerator.unwrap_model(discriminator)
                    # save checkpoint
                    state_dict = {
                        "model": unwrapped_model.state_dict(),
                        "discriminator": unwrapped_discriminator.state_dict(),
                        "opt_ae": opt_ae.state_dict(),
                        "opt_disc": opt_disc.state_dict(),
                        "step": step,
                        "epoch": epoch
                    }

                    torch.save(state_dict, "vqgan.pth")
            else:
                # save checkpoint
                state_dict = {
                    "model": model.state_dict(),
                    "discriminator": discriminator.state_dict(),
                    "opt_ae": opt_ae.state_dict(),
                    "opt_disc": opt_disc.state_dict(),
                    "step": step,
                    "epoch": epoch
                }
                torch.save(state_dict, "vqgan.pth")
    return model, discriminator, opt_ae, opt_disc


def train_step(accelerator, disc_start, discriminator, model, perceptual_factor, perceptual_loss_fn, rec_factor, step,
               x):
    x_recon, vq_loss = model(x)
    disc_factor = 0 if disc_start > step else 1

    # generator-side losses
    perceptual_loss = perceptual_loss_fn(x, x_recon).mean()
    rec_loss = F.l1_loss(x_recon, x)
    perceptual_rec_loss = perceptual_factor * perceptual_loss + rec_factor * rec_loss
    g_loss = -torch.mean(discriminator(x_recon))
    if accelerator:
        balance_factor = model.module.calculate_balance_factor(perceptual_rec_loss, g_loss)
    else:
        balance_factor = model.calculate_balance_factor(perceptual_rec_loss, g_loss)
    total_loss = perceptual_rec_loss + vq_loss + disc_factor * balance_factor * g_loss

    # discriminator-side losses: the reconstruction is detached so the
    # discriminator loss does not backpropagate into the generator
    disc_real = discriminator(x)
    disc_fake = discriminator(x_recon.detach())
    d_real_loss = F.binary_cross_entropy_with_logits(
        disc_real, torch.ones_like(disc_real))
    d_fake_loss = F.binary_cross_entropy_with_logits(
        disc_fake, torch.zeros_like(disc_fake))
    disc_loss = disc_factor * 0.5 * (d_real_loss + d_fake_loss)

    return disc_loss, g_loss, perceptual_loss, rec_loss, total_loss, vq_loss, x_recon


def get_imagenet_dataloader(batch_size=32, data_path="datasets/faces"):
    # data loading with ImageNet normalization
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    train_dataset = ImageFolder(data_path, transform=transform)

    return DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=4)


# --------------------------
# Usage example
# --------------------------

if __name__ == "__main__":
    # data loading (example)
    train_loader = get_imagenet_dataloader(batch_size=12, data_path="faces")
    # start training
    train_vqgan(train_loader, epochs=100, mixed_precision=True, accelerate=True, disc_start=10000, rec_factor=1,
                perceptual_factor=1, learning_rate=4.5e-6, in_channels=3, groups=8, z_channels=4, embedding_dim=4,
                num_embeddings=8196, beta=0.25, decay=0.99, epsilon=1e-5)
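
After training, a reconstruction pass looks roughly like this (a minimal sketch: it assumes the vqgan.pth checkpoint layout saved above, and face.jpg is a placeholder input path):

# minimal reconstruction sketch for a trained checkpoint
import torch
from PIL import Image
from torchvision import transforms
from vae import VQVAE

model = VQVAE(in_channels=3, groups=8, z_channels=4, embedding_dim=4, num_embeddings=8196)
model.load_state_dict(torch.load("vqgan.pth", map_location="cpu")["model"])
model.eval()

# same preprocessing as the training dataloader
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
x = transform(Image.open("face.jpg").convert("RGB")).unsqueeze(0)
with torch.no_grad():
    x_recon, _ = model(x)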

