def train_style_transfer(args):
    """Train the style-transfer model selected by ``args.model_type``.

    Two model types are handled:

    * ``"VAE"``    -- a single ``Autoencoder`` optimised with ``loss_func``.
    * ``"VAEGAN"`` -- an ``Autoencoder`` generator trained against a
      ``Discriminator`` with a summed binary cross-entropy adversarial loss
      added to ``loss_func``.

    Every epoch runs a "train" phase followed by a "valid" phase.  The weights
    with the lowest validation loss are restored before returning.

    Args:
        args: namespace-like object.  Attributes read here: ``train_data``,
            ``valid_data``, ``batch_size``, ``gpu``, ``model_type``, ``lr``,
            ``weight_decay``, ``epochs``, ``generator_model`` and, for
            ``"VAEGAN"``, ``discriminator_model`` and ``classifier_model``
            (optional checkpoint paths, loaded only when the file exists).

    Returns:
        ``(net, label_dict)`` for ``"VAE"``,
        ``((generator, discriminator), label_dict)`` for ``"VAEGAN"``.
        Any other ``model_type`` falls through and returns ``None``.

    Raises:
        SystemExit: when ``args.train_data`` or ``args.valid_data`` is falsy.
    """
    # Local import: state_dict() returns references to the live parameters,
    # so "best weights" snapshots must be deep copies.
    import copy

    if not (args.train_data and args.valid_data):
        print("must chose train_data and valid_data")
        sys.exit()

    # make dataset
    trans = transforms.ToTensor()
    train_dataset = FaceDataset(args.train_data, transform=trans)
    label_dict = train_dataset.get_label_dict()
    valid_dataset = FaceDataset(args.valid_data, transform=trans)
    # The validation set reuses the label mapping built from the training set.
    valid_dataset.give_label_dict(label_dict)
    train_loader = data_utils.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    valid_loader = data_utils.DataLoader(valid_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    loaders = {"train": train_loader, "valid": valid_loader}
    dataset_sizes = {"train": len(train_dataset), "valid": len(valid_dataset)}

    use_cuda = bool(args.gpu) and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # make network
    if args.model_type == "VAE":
        net = Autoencoder(train_dataset.label_num()).to(device)
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
        if args.generator_model and os.path.exists(args.generator_model):
            # map_location lets a checkpoint saved on GPU load on CPU too.
            net.load_state_dict(
                torch.load(args.generator_model, map_location=device))
        # Snapshot AFTER loading so the fallback is the starting checkpoint.
        best_model_wts = copy.deepcopy(net.state_dict())
        best_loss = 1e10
    elif args.model_type == "VAEGAN":
        generator = Autoencoder(train_dataset.label_num()).to(device)
        discriminator = Discriminator().to(device)
        classifier = Classifier(train_dataset.label_num()).to(device)
        generator_optimizer = optim.Adam(generator.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        # The discriminator uses a 10x smaller learning rate than the generator.
        discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                             lr=args.lr * 0.1,
                                             weight_decay=args.weight_decay)
        if args.generator_model and os.path.exists(args.generator_model):
            generator.load_state_dict(
                torch.load(args.generator_model, map_location=device))
        if args.discriminator_model and os.path.exists(
                args.discriminator_model):
            discriminator.load_state_dict(
                torch.load(args.discriminator_model, map_location=device))
        if args.classifier_model and os.path.exists(args.classifier_model):
            classifier.load_state_dict(
                torch.load(args.classifier_model, map_location=device))
        best_generator_wts = copy.deepcopy(generator.state_dict())
        best_discriminator_wts = copy.deepcopy(discriminator.state_dict())
        best_generator_loss = 1e10
        best_discriminator_loss = 1e10

    # make loss function and optimizer
    criterion = nn.BCELoss(reduction="sum")
    # NOTE(review): only needed by the classifier loss term, which is
    # currently disabled in the generator update below.
    classifier_criterion = nn.CrossEntropyLoss(reduction="sum")

    # initialize loss (only the VAE branch records into it)
    loss_history = {"train": [], "valid": []}

    # start training
    start_time = time.time()
    for epoch in range(args.epochs):
        print("epoch {}".format(epoch + 1))
        for phase in ["train", "valid"]:
            is_train = phase == "train"
            if args.model_type == "VAE":
                net.train(is_train)
            elif args.model_type == "VAEGAN":
                generator.train(is_train)
                discriminator.train(is_train)

            # initialize running loss
            generator_running_loss = 0.0
            discriminator_running_loss = 0.0

            # Context-manager form restores the caller's grad mode on exit,
            # instead of leaving autograd globally disabled after "valid".
            with torch.set_grad_enabled(is_train):
                for inputs, label in loaders[phase]:
                    inputs = inputs.to(device)
                    label = label.to(device)

                    if args.model_type == "VAE":
                        optimizer.zero_grad()
                        mu, var, outputs = net(inputs, label)
                        loss = loss_func(inputs, outputs, mu, var)
                        if is_train:
                            loss.backward()
                            optimizer.step()
                        generator_running_loss += loss.item()
                    elif args.model_type == "VAEGAN":
                        batch = inputs.size()[0]
                        # Noisy targets: real in (0.8, 1.0], fake in [0.0, 0.2).
                        real_label = (
                            torch.ones((batch, 1), dtype=torch.float) -
                            0.2 * torch.rand(batch, 1)).to(device)
                        fake_label = (
                            torch.zeros((batch, 1), dtype=torch.float) +
                            0.2 * torch.rand(batch, 1)).to(device)

                        # --- discriminator update ---
                        discriminator_optimizer.zero_grad()
                        real_pred = discriminator(inputs)
                        real_loss = criterion(real_pred, real_label)

                        # NOTE(review): a random target label is built here but
                        # the generator below is still fed the true `label`, so
                        # `generate_label` is unused.  Kept so numpy RNG
                        # consumption is unchanged -- confirm intent.
                        random_index = np.random.randint(
                            0, train_dataset.label_num(), batch)
                        generate_label = torch.zeros_like(label).to(device)
                        for row, index in enumerate(random_index):
                            generate_label[row][index] = 1

                        mu, var, outputs = generator(inputs, label)
                        # detach(): the discriminator step must not
                        # backpropagate into the generator.
                        fake_pred = discriminator(outputs.detach())
                        fake_loss = criterion(fake_pred, fake_label)
                        discriminator_loss = real_loss + fake_loss
                        if is_train:
                            discriminator_loss.backward()
                            discriminator_optimizer.step()

                        # --- generator update ---
                        generator_optimizer.zero_grad()
                        # The generator is scored against the "real" targets.
                        dis_loss = criterion(discriminator(outputs),
                                             real_label)
                        gen_loss = loss_func(inputs, outputs, mu, var)
                        generator_loss = dis_loss + gen_loss
                        if is_train:
                            generator_loss.backward()
                            generator_optimizer.step()

                        discriminator_running_loss += discriminator_loss.item()
                        generator_running_loss += generator_loss.item()

            if args.model_type == "VAE":
                epoch_loss = generator_running_loss / dataset_sizes[
                    phase] * args.batch_size
                loss_history[phase].append(epoch_loss)
                print("{} loss {:.4f}".format(phase, epoch_loss))
                if phase == "valid" and epoch_loss < best_loss:
                    best_model_wts = copy.deepcopy(net.state_dict())
                    best_loss = epoch_loss
            elif args.model_type == "VAEGAN":
                epoch_generator_loss = generator_running_loss / dataset_sizes[
                    phase] * args.batch_size
                epoch_discriminator_loss = (discriminator_running_loss /
                                            dataset_sizes[phase] *
                                            args.batch_size)
                print("{} generator loss {:.4f}".format(
                    phase, epoch_generator_loss))
                print("{} discriminator loss {:.4f}".format(
                    phase, epoch_discriminator_loss))
                if phase == "valid" and epoch_generator_loss < best_generator_loss:
                    best_generator_wts = copy.deepcopy(generator.state_dict())
                    best_generator_loss = epoch_generator_loss
                if (phase == "valid" and
                        epoch_discriminator_loss < best_discriminator_loss):
                    best_discriminator_wts = copy.deepcopy(
                        discriminator.state_dict())
                    # Fixed: this used to overwrite best_generator_loss.
                    best_discriminator_loss = epoch_discriminator_loss

    elapsed_time = time.time() - start_time
    print("training complete in {:.0f}s".format(elapsed_time))

    if args.model_type == "VAE":
        net.load_state_dict(best_model_wts)
        return net, label_dict
    elif args.model_type == "VAEGAN":
        generator.load_state_dict(best_generator_wts)
        discriminator.load_state_dict(best_discriminator_wts)
        return (generator, discriminator), label_dict
def train_classifier(args):
    """Train a ``Classifier`` on the face dataset and return the best weights.

    Every epoch runs a "train" phase followed by a "valid" phase.  The loss is
    summed cross-entropy plus a small penalty on the squared sum of squared
    parameters (``1e-9 * (sum of p*p) ** 2``).  The weights with the lowest
    validation loss are restored before returning.

    Args:
        args: namespace-like object.  Attributes read here: ``train_data``,
            ``valid_data``, ``batch_size``, ``gpu``, ``lr``, ``weight_decay``,
            ``epochs`` and ``classifier_model`` (optional checkpoint path,
            loaded only when the file exists).

    Returns:
        ``(classifier, label_dict)``.

    Raises:
        SystemExit: when ``args.train_data`` or ``args.valid_data`` is falsy.
    """
    # Local import: state_dict() returns references to the live parameters,
    # so the "best weights" snapshot must be a deep copy.
    import copy

    if not (args.train_data and args.valid_data):
        print("must chose train_data and valid_data")
        sys.exit()

    trans = transforms.ToTensor()
    train_dataset = FaceDataset(args.train_data, transform=trans)
    label_dict = train_dataset.get_label_dict()
    valid_dataset = FaceDataset(args.valid_data, transform=trans)
    # The validation set reuses the label mapping built from the training set.
    valid_dataset.give_label_dict(label_dict)
    train_loader = data_utils.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    valid_loader = data_utils.DataLoader(valid_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    loaders = {"train": train_loader, "valid": valid_loader}
    dataset_sizes = {"train": len(train_dataset), "valid": len(valid_dataset)}

    use_cuda = bool(args.gpu) and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    classifier = Classifier(len(label_dict)).to(device).float()
    optimizer = optim.Adam(classifier.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    if args.classifier_model and os.path.exists(args.classifier_model):
        # map_location lets a checkpoint saved on GPU load on CPU too.
        classifier.load_state_dict(
            torch.load(args.classifier_model, map_location=device))
    # Snapshot AFTER loading so the fallback is the starting checkpoint.
    best_model_wts = copy.deepcopy(classifier.state_dict())
    best_loss = 1e10

    criterion = nn.CrossEntropyLoss(reduction="sum")

    start_time = time.time()
    for epoch in range(args.epochs):
        print("epoch {}".format(epoch + 1))
        for phase in ["train", "valid"]:
            is_train = phase == "train"
            classifier.train(is_train)
            running_loss = 0.0
            running_acc = 0

            # Context-manager form restores the caller's grad mode on exit,
            # instead of leaving autograd globally disabled after "valid".
            with torch.set_grad_enabled(is_train):
                for inputs, label in loaders[phase]:
                    inputs = inputs.to(device)
                    label = label.to(device)
                    # Labels are reduced to class indices via argmax over
                    # dim 1 (presumably one-hot encoded -- confirm against
                    # FaceDataset).
                    target = torch.max(label, 1)[1]

                    optimizer.zero_grad()
                    pred = classifier(inputs)

                    reg_loss = 0
                    for param in classifier.parameters():
                        reg_loss += (param * param).sum()
                    loss = criterion(pred,
                                     target) + 1e-9 * reg_loss * reg_loss

                    if is_train:
                        loss.backward()
                        optimizer.step()

                    running_loss += loss.item()
                    running_acc += (torch.max(pred,
                                              1)[1] == target).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase] * args.batch_size
            epoch_acc = running_acc / dataset_sizes[phase]
            print("{} loss {:.4f}".format(phase, epoch_loss))
            print("{} acc {:.6f}".format(phase, epoch_acc))

            if phase == "valid" and epoch_loss < best_loss:
                best_model_wts = copy.deepcopy(classifier.state_dict())
                best_loss = epoch_loss

    elapsed_time = time.time() - start_time
    print("training_complete in {:.0f}".format(elapsed_time))

    classifier.load_state_dict(best_model_wts)
    return classifier, label_dict