import torch

# Model definitions; the discriminator/generator paths are assumed to follow
# the same `models` package layout as the GuidingNet import used below.
from models.discriminator import Discriminator
from models.generator import Generator
from models.guidingNet import GuidingNet


def build_model(args):
    args.to_train = 'CDG'

    networks = {}
    opts = {}

    # Guiding network C (continuous style code + discrete domain logits) and its EMA copy.
    if 'C' in args.to_train:
        networks['C'] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
        networks['C_EMA'] = GuidingNet(args.img_size, {'cont': args.sty_dim, 'disc': args.output_k})
    # Multi-domain discriminator D.
    if 'D' in args.to_train:
        networks['D'] = Discriminator(args.img_size, num_domains=args.output_k)
    # Generator G and its EMA copy.
    if 'G' in args.to_train:
        networks['G'] = Generator(args.img_size, args.sty_dim, use_sn=False)
        networks['G_EMA'] = Generator(args.img_size, args.sty_dim, use_sn=False)

    if args.distributed:
        if args.gpu is not None:
            print('Distributed to', args.gpu)
            torch.cuda.set_device(args.gpu)
            # Split the global batch and the data-loading workers across the
            # processes on this node.
            args.batch_size = int(args.batch_size / args.ngpus_per_node)
            args.workers = int(args.workers / args.ngpus_per_node)
            for name, net in networks.items():
                if name in ['inceptionNet']:
                    continue
                net_tmp = net.cuda(args.gpu)
                networks[name] = torch.nn.parallel.DistributedDataParallel(
                    net_tmp, device_ids=[args.gpu], output_device=args.gpu)
        else:
            for name, net in networks.items():
                net_tmp = net.cuda()
                networks[name] = torch.nn.parallel.DistributedDataParallel(net_tmp)
    elif args.gpu is not None:
        torch.cuda.set_device(args.gpu)
        for name, net in networks.items():
            networks[name] = net.cuda(args.gpu)
    else:
        for name, net in networks.items():
            networks[name] = torch.nn.DataParallel(net).cuda()

    if 'C' in args.to_train:
        opts['C'] = torch.optim.Adam(
            networks['C'].module.parameters() if args.distributed else networks['C'].parameters(),
            1e-4, weight_decay=0.001)
        # Initialize the EMA copy from the online guiding network.
        if args.distributed:
            networks['C_EMA'].module.load_state_dict(networks['C'].module.state_dict())
        else:
            networks['C_EMA'].load_state_dict(networks['C'].state_dict())
    if 'D' in args.to_train:
        opts['D'] = torch.optim.RMSprop(
            networks['D'].module.parameters() if args.distributed else networks['D'].parameters(),
            1e-4, weight_decay=0.0001)
    if 'G' in args.to_train:
        opts['G'] = torch.optim.RMSprop(
            networks['G'].module.parameters() if args.distributed else networks['G'].parameters(),
            1e-4, weight_decay=0.0001)

    return networks, opts
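
# Illustrative sketch (assumed values, not the repository's defaults): how
# build_model is typically driven in a single-GPU, non-distributed run. It
# requires CUDA, since even the non-distributed branch moves every network
# to the GPU. Never called automatically; kept here only as documentation.
def _example_build_model_usage():
    import argparse
    args = argparse.Namespace(
        img_size=128,       # input resolution fed to the networks
        sty_dim=128,        # style-code dimensionality ('cont' head of GuidingNet)
        output_k=10,        # number of (pseudo-)domains ('disc' head / D's domains)
        batch_size=32,
        workers=4,
        distributed=False,  # take the `elif args.gpu is not None` branch above
        gpu=0,
        ngpus_per_node=1,
    )
    networks, opts = build_model(args)
    print(sorted(networks.keys()))  # ['C', 'C_EMA', 'D', 'G', 'G_EMA']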
def assign_adain_params(adain_params, model):
    # Assign the adain_params to the AdaIN layers in model. Each AdaIN2d layer
    # consumes its 2 * num_features slice (mean, then std); the remainder of
    # the tensor is passed on to the next AdaIN2d layer.
    for m in model.modules():
        if m.__class__.__name__ == "AdaIN2d":
            mean = adain_params[:, :m.num_features]
            std = adain_params[:, m.num_features:2 * m.num_features]
            m.bias = mean.contiguous().view(-1)
            m.weight = std.contiguous().view(-1)
            if adain_params.size(1) > 2 * m.num_features:
                adain_params = adain_params[:, 2 * m.num_features:]


def get_num_adain_params(model):
    # Return the number of AdaIN parameters needed by the model:
    # a mean and a std per feature map of every AdaIN2d layer.
    num_adain_params = 0
    for m in model.modules():
        if m.__class__.__name__ == "AdaIN2d":
            num_adain_params += 2 * m.num_features
    return num_adain_params


if __name__ == '__main__':
    from models.guidingNet import GuidingNet

    C = GuidingNet(64)
    G = Generator(64, 128, 4)

    x_in = torch.randn(4, 3, 64, 64)

    # Content code from the generator's encoder, style code from the guiding network.
    cont = G.cnt_encoder(x_in)
    sty = C.moco(x_in)
    x_out = G.decode(cont, sty)

    print(x_out.shape)
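
# Self-contained sketch of the AdaIN-helper contract. assign_adain_params() and
# get_num_adain_params() only require a module whose class is named "AdaIN2d"
# and which exposes `num_features` plus writable `weight`/`bias`; the AdaIN2d
# below is a toy stand-in (an assumption), not the repository's implementation.
def _example_adain_helpers():
    import torch.nn as nn
    import torch.nn.functional as F

    class AdaIN2d(nn.Module):
        """Instance norm whose affine params are injected by assign_adain_params()."""

        def __init__(self, num_features):
            super().__init__()
            self.num_features = num_features
            self.weight = None  # per-sample std, filled in externally
            self.bias = None    # per-sample mean, filled in externally

        def forward(self, x):
            b, c = x.size(0), x.size(1)
            out = F.instance_norm(x)
            return out * self.weight.view(b, c, 1, 1) + self.bias.view(b, c, 1, 1)

    model = nn.Sequential(AdaIN2d(8), nn.ReLU(), AdaIN2d(8))
    n = get_num_adain_params(model)  # 2 * 8 + 2 * 8 = 32
    style = torch.randn(1, n)        # e.g. the output of a style MLP
    assign_adain_params(style, model)
    y = model(torch.randn(1, 8, 16, 16))
    print(n, y.shape)                # 32 torch.Size([1, 8, 16, 16])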