Esempio n. 1
0
def train_style_transfer(args):
    """Train a conditional autoencoder ("VAE") or autoencoder + discriminator ("VAEGAN").

    Each epoch runs a "train" phase followed by a "valid" phase. The weights
    that achieve the lowest validation loss are snapshotted and restored into
    the model(s) before returning.

    Args:
        args: Namespace-like object. Attributes read here: ``train_data``,
            ``valid_data``, ``batch_size``, ``gpu``, ``model_type``
            ("VAE" or "VAEGAN"), ``lr``, ``weight_decay``, ``epochs``,
            ``generator_model``, and for "VAEGAN" also
            ``discriminator_model`` and ``classifier_model`` (optional
            checkpoint paths).

    Returns:
        ``(net, label_dict)`` for "VAE";
        ``((generator, discriminator), label_dict)`` for "VAEGAN".
        For any other ``model_type`` nothing is trained and ``None`` is
        returned implicitly.

    Exits the process via ``sys.exit()`` when ``train_data`` or
    ``valid_data`` is missing.
    """
    # Local import so this function stays self-contained in the file.
    import copy

    if not (args.train_data and args.valid_data):
        print("must chose train_data and valid_data")
        sys.exit()

    # make dataset
    trans = transforms.ToTensor()
    train_dataset = FaceDataset(args.train_data, transform=trans)
    label_dict = train_dataset.get_label_dict()
    valid_dataset = FaceDataset(args.valid_data, transform=trans)
    # The validation set reuses the label mapping built from the training set.
    valid_dataset.give_label_dict(label_dict)
    train_loader = data_utils.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    valid_loader = data_utils.DataLoader(valid_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    loaders = {"train": train_loader, "valid": valid_loader}
    dataset_sizes = {"train": len(train_dataset), "valid": len(valid_dataset)}

    use_cuda = bool(args.gpu) and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # make network
    if args.model_type == "VAE":
        net = Autoencoder(train_dataset.label_num()).to(device)
        optimizer = optim.Adam(net.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
        if args.generator_model and os.path.exists(args.generator_model):
            net.load_state_dict(
                torch.load(args.generator_model, map_location=device))
        # state_dict() returns references to the live parameters, so the
        # snapshot must be a deep copy (taken after any checkpoint is loaded)
        # or it would silently track every later optimizer step.
        best_model_wts = copy.deepcopy(net.state_dict())
        best_loss = 1e10

    elif args.model_type == "VAEGAN":
        generator = Autoencoder(train_dataset.label_num()).to(device)
        discriminator = Discriminator().to(device)
        classifier = Classifier(train_dataset.label_num()).to(device)
        generator_optimizer = optim.Adam(generator.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay)
        # The discriminator is trained with a 10x smaller learning rate.
        discriminator_optimizer = optim.Adam(discriminator.parameters(),
                                             lr=args.lr * 0.1,
                                             weight_decay=args.weight_decay)
        if args.generator_model and os.path.exists(args.generator_model):
            generator.load_state_dict(
                torch.load(args.generator_model, map_location=device))
        if args.discriminator_model and os.path.exists(
                args.discriminator_model):
            discriminator.load_state_dict(
                torch.load(args.discriminator_model, map_location=device))
        # Same existence guard as the other checkpoints.
        if args.classifier_model and os.path.exists(args.classifier_model):
            classifier.load_state_dict(
                torch.load(args.classifier_model, map_location=device))
        # Deep copies for the same reason as in the VAE branch.
        best_generator_wts = copy.deepcopy(generator.state_dict())
        best_discriminator_wts = copy.deepcopy(discriminator.state_dict())
        best_generator_loss = 1e10
        best_discriminator_loss = 1e10

    # make loss function and optimizer
    criterion = nn.BCELoss(reduction="sum")
    # NOTE(review): the classifier and this criterion are set up but the
    # classification loss term is not part of the generator objective below.
    classifier_criterion = nn.CrossEntropyLoss(reduction="sum")

    # initialize loss
    loss_history = {"train": [], "valid": []}

    # start training
    start_time = time.time()
    for epoch in range(args.epochs):
        print("epoch {}".format(epoch + 1))

        for phase in ["train", "valid"]:
            is_train = phase == "train"
            if args.model_type == "VAE":
                net.train(is_train)
            elif args.model_type == "VAEGAN":
                generator.train(is_train)
                discriminator.train(is_train)

            # initialize running loss
            generator_running_loss = 0.0
            discriminator_running_loss = 0.0

            # Scope the grad mode to this phase; the context manager restores
            # the previous global mode on exit, so autograd is not left
            # disabled for the caller once the final "valid" phase ends.
            with torch.set_grad_enabled(is_train):
                for inputs, label in loaders[phase]:
                    inputs = inputs.to(device)
                    label = label.to(device)
                    batch = inputs.size(0)

                    if args.model_type == "VAE":
                        optimizer.zero_grad()
                        mu, var, outputs = net(inputs, label)
                        loss = loss_func(inputs, outputs, mu, var)
                        if is_train:
                            loss.backward()
                            optimizer.step()
                        generator_running_loss += loss.item()

                    elif args.model_type == "VAEGAN":
                        # Noisy (smoothed) targets: real in (0.8, 1.0],
                        # fake in [0.0, 0.2).
                        real_label = (
                            torch.ones((batch, 1), dtype=torch.float) -
                            0.2 * torch.rand(batch, 1)).to(device)
                        fake_label = (
                            torch.zeros((batch, 1), dtype=torch.float) +
                            0.2 * torch.rand(batch, 1)).to(device)

                        # --- discriminator step ---
                        discriminator_optimizer.zero_grad()
                        real_pred = discriminator(inputs)
                        real_loss = criterion(real_pred, real_label)

                        # NOTE(review): an earlier revision built a random
                        # one-hot target label here but never passed it to
                        # the generator; the generator is conditioned on the
                        # true label, as before.
                        mu, var, outputs = generator(inputs, label)
                        # detach() keeps discriminator gradients out of the
                        # generator during the discriminator update.
                        fake_pred = discriminator(outputs.detach())
                        fake_loss = criterion(fake_pred, fake_label)

                        discriminator_loss = real_loss + fake_loss
                        if is_train:
                            discriminator_loss.backward()
                            discriminator_optimizer.step()

                        # --- generator step ---
                        generator_optimizer.zero_grad()
                        # The generator is rewarded when the discriminator
                        # scores its outputs as real.
                        dis_loss = criterion(discriminator(outputs),
                                             real_label)
                        gen_loss = loss_func(inputs, outputs, mu, var)
                        generator_loss = dis_loss + gen_loss
                        if is_train:
                            generator_loss.backward()
                            generator_optimizer.step()

                        discriminator_running_loss += discriminator_loss.item()
                        generator_running_loss += generator_loss.item()

            if args.model_type == "VAE":
                epoch_loss = generator_running_loss / dataset_sizes[
                    phase] * args.batch_size
                loss_history[phase].append(epoch_loss)

                print("{} loss {:.4f}".format(phase, epoch_loss))
                if phase == "valid" and epoch_loss < best_loss:
                    best_model_wts = copy.deepcopy(net.state_dict())
                    best_loss = epoch_loss

            elif args.model_type == "VAEGAN":
                epoch_generator_loss = generator_running_loss / dataset_sizes[
                    phase] * args.batch_size
                epoch_discriminator_loss = (discriminator_running_loss /
                                            dataset_sizes[phase] *
                                            args.batch_size)

                print("{} generator loss {:.4f}".format(
                    phase, epoch_generator_loss))
                print("{} discriminator loss {:.4f}".format(
                    phase, epoch_discriminator_loss))
                if (phase == "valid"
                        and epoch_generator_loss < best_generator_loss):
                    best_generator_wts = copy.deepcopy(generator.state_dict())
                    best_generator_loss = epoch_generator_loss
                if (phase == "valid" and
                        epoch_discriminator_loss < best_discriminator_loss):
                    best_discriminator_wts = copy.deepcopy(
                        discriminator.state_dict())
                    # Track the discriminator's own best loss; writing it to
                    # best_generator_loss would corrupt generator selection.
                    best_discriminator_loss = epoch_discriminator_loss

    elapsed_time = time.time() - start_time
    print("training complete in {:.0f}s".format(elapsed_time))
    if args.model_type == "VAE":
        net.load_state_dict(best_model_wts)
        return net, label_dict

    elif args.model_type == "VAEGAN":
        generator.load_state_dict(best_generator_wts)
        discriminator.load_state_dict(best_discriminator_wts)
        return (generator, discriminator), label_dict
Esempio n. 2
0
def train_classifier(args):
    """Train the face classifier and return it with the best validation weights.

    Each epoch runs a "train" phase followed by a "valid" phase. The loss is
    summed cross-entropy plus a small penalty on the squared sum of squared
    parameters. The weights with the lowest validation loss are snapshotted
    and restored before returning.

    Args:
        args: Namespace-like object. Attributes read here: ``train_data``,
            ``valid_data``, ``batch_size``, ``gpu``, ``lr``,
            ``weight_decay``, ``epochs`` and ``classifier_model`` (optional
            checkpoint path).

    Returns:
        ``(classifier, label_dict)``.

    Exits the process via ``sys.exit()`` when ``train_data`` or
    ``valid_data`` is missing.
    """
    # Local import so this function stays self-contained in the file.
    import copy

    if not (args.train_data and args.valid_data):
        print("must chose train_data and valid_data")
        sys.exit()

    trans = transforms.ToTensor()
    train_dataset = FaceDataset(args.train_data, transform=trans)
    label_dict = train_dataset.get_label_dict()
    valid_dataset = FaceDataset(args.valid_data, transform=trans)
    # The validation set reuses the label mapping built from the training set.
    valid_dataset.give_label_dict(label_dict)
    train_loader = data_utils.DataLoader(train_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    valid_loader = data_utils.DataLoader(valid_dataset,
                                         batch_size=args.batch_size,
                                         shuffle=True,
                                         num_workers=1)
    loaders = {"train": train_loader, "valid": valid_loader}
    dataset_sizes = {"train": len(train_dataset), "valid": len(valid_dataset)}

    use_cuda = bool(args.gpu) and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    classifier = Classifier(len(label_dict)).to(device).float()
    optimizer = optim.Adam(classifier.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    if args.classifier_model and os.path.exists(args.classifier_model):
        classifier.load_state_dict(
            torch.load(args.classifier_model, map_location=device))
    # state_dict() returns references to the live parameters, so the snapshot
    # must be a deep copy (taken after any checkpoint is loaded) or it would
    # silently track every later optimizer step.
    best_model_wts = copy.deepcopy(classifier.state_dict())
    best_loss = 1e10
    criterion = nn.CrossEntropyLoss(reduction="sum")

    start_time = time.time()
    for epoch in range(args.epochs):
        print("epoch {}".format(epoch + 1))

        for phase in ["train", "valid"]:
            is_train = phase == "train"
            classifier.train(is_train)

            running_loss = 0.0
            running_acc = 0
            # Scope the grad mode to this phase; the context manager restores
            # the previous global mode on exit, so autograd is not left
            # disabled for the caller once the final "valid" phase ends.
            with torch.set_grad_enabled(is_train):
                for inputs, label in loaders[phase]:
                    inputs = inputs.to(device)
                    label = label.to(device)
                    # Class index per sample, taken as the position of the
                    # maximum along dim 1 (presumably one-hot labels -
                    # confirm against FaceDataset).
                    target = torch.max(label, 1)[1]

                    optimizer.zero_grad()
                    pred = classifier(inputs)
                    reg_loss = sum((param * param).sum()
                                   for param in classifier.parameters())

                    loss = criterion(pred,
                                     target) + 1e-9 * reg_loss * reg_loss
                    if is_train:
                        loss.backward()
                        optimizer.step()
                    running_loss += loss.item()
                    running_acc += (torch.max(pred,
                                              1)[1] == target).sum().item()

            epoch_loss = running_loss / dataset_sizes[phase] * args.batch_size
            epoch_acc = running_acc / dataset_sizes[phase]
            print("{} loss {:.4f}".format(phase, epoch_loss))
            print("{} acc {:.6f}".format(phase, epoch_acc))
            if phase == "valid" and epoch_loss < best_loss:
                best_model_wts = copy.deepcopy(classifier.state_dict())
                best_loss = epoch_loss

    elapsed_time = time.time() - start_time
    print("training_complete in {:.0f}".format(elapsed_time))
    classifier.load_state_dict(best_model_wts)
    return classifier, label_dict