Example #1
def num_param(args):
    """Print the total number of trainable parameters of the GAN (G + D)."""
    TT = args.fc_tensorized
    with open(CONFIG_DIR + args.config, 'r') as f:
        config = json.load(f)

    D = Discriminator(config, TT)
    G = Generator(config)
    # Count only trainable (requires_grad) parameters of both networks.
    G_params = sum(p.numel() for p in G.parameters() if p.requires_grad)
    D_params = sum(p.numel() for p in D.parameters() if p.requires_grad)
    params = G_params + D_params
    print("The model has: {} parameters".format(params))
Example #2
def training_process(device, nb_class_labels, model_path, result_dir, patience,
                     epochs, do_pre_train, tr_feat_path, tr_labels_path,
                     val_feat_path, val_labels_path, tr_batch_size,
                     val_batch_size, adapt_patience, adapt_epochs, d_lr,
                     tgt_lr, update_cnt, factor):
    """Implements the complete training process of the AUDASC method.

    :param device: The device that we will use.
    :type device: str
    :param nb_class_labels: The number of labels for label classification.
    :type nb_class_labels: int
    :param model_path: The path of a previously saved model (if any).
    :type model_path: str
    :param result_dir: The directory in which to save the newly pre-trained model.
    :type result_dir: str
    :param patience: The patience for the pre-training step.
    :type patience: int
    :param epochs: The epochs for the pre-training step.
    :type epochs: int
    :param do_pre_train: Flag to indicate if we do pre-training.
    :type do_pre_train: bool
    :param tr_feat_path: The path for loading the training features.
    :type tr_feat_path: str
    :param tr_labels_path: The path for loading the training labels.
    :type tr_labels_path: str
    :param val_feat_path: The path for loading the validation features.
    :type val_feat_path: str
    :param val_labels_path: The path for loading the validation labels.
    :type val_labels_path: str
    :param tr_batch_size: The batch size used for pre-training.
    :type tr_batch_size: int
    :param val_batch_size: The batch size used for validation.
    :type val_batch_size: int
    :param adapt_patience: The patience for the domain adaptation step.
    :type adapt_patience: int
    :param adapt_epochs: The epochs for the domain adaptation step.
    :type adapt_epochs: int
    :param d_lr: The learning rate for the discriminator.
    :type d_lr: float
    :param tgt_lr: The learning rate for the adapted model.
    :type tgt_lr: float
    :param update_cnt: An update controller for the adversarial loss.
    :type update_cnt: int
    :param factor: The coefficient by which the classification loss is multiplied.
    :type factor: int
    """

    tr_feat = device_exchange(file_io.load_pickled_features(tr_feat_path),
                              device=device)
    tr_labels = device_exchange(file_io.load_pickled_features(tr_labels_path),
                                device=device)
    val_feat = device_exchange(file_io.load_pickled_features(val_feat_path),
                               device=device)
    val_labels = device_exchange(
        file_io.load_pickled_features(val_labels_path), device=device)

    loss_func = functional.cross_entropy

    non_adapted_cnn = Model().to(device)
    label_classifier = LabelClassifier(nb_class_labels).to(device)

    if not path.exists(result_dir):
        makedirs(result_dir)

    if do_pre_train:
        state_dict_path = result_dir

        printing.info_msg('Pre-training step')

        optimizer_source = torch.optim.Adam(
            list(non_adapted_cnn.parameters()) +
            list(label_classifier.parameters()),
            lr=1e-4)

        pre_training.pre_training(model=non_adapted_cnn,
                                  label_classifier=label_classifier,
                                  optimizer=optimizer_source,
                                  tr_batch_size=tr_batch_size,
                                  val_batch_size=val_batch_size,
                                  tr_feat=tr_feat['A'],
                                  tr_labels=tr_labels['A'],
                                  val_feat=val_feat['A'],
                                  val_labels=val_labels['A'],
                                  epochs=epochs,
                                  criterion=loss_func,
                                  patience=patience,
                                  result_dir=state_dict_path)

        del optimizer_source

    else:
        printing.info_msg('Loading a pre-trained non-adapted model')
        state_dict_path = model_path

    if not path.exists(state_dict_path):
        raise ValueError(
            'The path for loading the pre-trained model does not exist!')

    non_adapted_cnn.load_state_dict(
        torch.load(path.join(state_dict_path, 'non_adapted_cnn.pytorch')))
    label_classifier.load_state_dict(
        torch.load(path.join(state_dict_path, 'label_classifier.pytorch')))

    printing.info_msg('Training the Adversarial Adaptation Model')

    target_cnn = Model().to(device)
    target_cnn.load_state_dict(non_adapted_cnn.state_dict())
    discriminator = Discriminator(2).to(device)

    target_model_opt = torch.optim.Adam(target_cnn.parameters(), lr=tgt_lr)
    discriminator_opt = torch.optim.Adam(discriminator.parameters(), lr=d_lr)

    domain_adaptation.domain_adaptation(
        non_adapted_cnn, target_cnn, label_classifier, discriminator,
        target_model_opt, discriminator_opt, loss_func, loss_func, loss_func,
        tr_feat, tr_labels, val_feat, val_labels, adapt_epochs, update_cnt,
        result_dir, adapt_patience, device, factor)
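
The adaptation step above warm-starts the target network from the pre-trained source weights (target_cnn.load_state_dict(non_adapted_cnn.state_dict())) and restores the source model from disk with torch.load. A minimal sketch of that save/restore pattern in plain PyTorch; SmallCNN and the file name are stand-ins for illustration, not the AUDASC classes:

import torch
import torch.nn as nn

class SmallCNN(nn.Module):               # stand-in for the repository's Model()
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)

source = SmallCNN()
torch.save(source.state_dict(), 'source_cnn.pytorch')     # persist pre-trained weights

target = SmallCNN()
target.load_state_dict(torch.load('source_cnn.pytorch'))  # warm-start the target model
# From here on, only the target model (and a discriminator) would be updated.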
Example #3
def train(args):
    pre_trained = args.pre_trained
    PATH = args.path_results
    lrD = args.lrD
    lrG = args.lrG
    epochs = args.epochs
    batch_size = args.batch
    device = args.device
    save_every = args.save_every
    data = args.data
    with open(CONFIG_DIR + args.config, 'r') as f:
        config = json.load(f)
    TT = args.fc_tensorized

    print(TT)  # log the fc_tensorized flag

    # Create directory for results
    if not os.path.isdir(PATH):
        os.mkdir(PATH)
    # Create directory for specific run
    if TT:
        PATH = PATH + "/{}_ttfc".format(config["id"])
    else:
        PATH = PATH + "/{}".format(config["id"])
    if not os.path.isdir(PATH):
        os.mkdir(PATH)
    if not os.path.isdir(PATH + '/Random_results'):
        os.mkdir(PATH + '/Random_results')
    if not os.path.isdir(PATH + '/Fixed_results'):
        os.mkdir(PATH + '/Fixed_results')

    print("### Loading data ###")
    train_loader = load_dataset(data, batch_size, is_train=True)
    print("### Loaded data ###")

    print("### Create models ###")
    D = Discriminator(config, TT).to(device)
    G = Generator(config).to(device)
    model_parameters = filter(lambda p: p.requires_grad, D.parameters())
    params = sum([np.prod(p.size()) for p in model_parameters])
    model_parameters = filter(lambda p: p.requires_grad, G.parameters())
    params += sum([np.prod(p.size()) for p in model_parameters])
    print("The model has:{} parameters".format(params))
    if pre_trained:
        D.encoder.load()
        G.decoder.load()

    G_optimizer = optim.Adam(G.parameters(), lr=lrG, betas=(0.5, 0.999))
    D_optimizer = optim.Adam(D.parameters(), lr=lrD, betas=(0.5, 0.999))

    train_hist = {'D_losses': [], 'G_losses': [], 'G_fix_losses': []}

    BCE_loss = nn.BCELoss()
    fixed_z_ = torch.randn((5 * 5, 100)).to(device)  # fixed noise
    for epoch in range(epochs):
        # Snapshot of D; only consumed by the commented-out G.evaluate call below.
        if epoch == 1 or epoch % save_every == 0:
            D_test = copy.deepcopy(D)
        D_losses = []
        G_losses = []
        G_fix_losses = []
        for x, _ in train_loader:
            x = x.to(device)
            D_loss = D.train_step(x, G, D_optimizer, BCE_loss, device)
            G_loss = G.train_step(D, batch_size, G_optimizer, BCE_loss, device)
            # G_fix_loss = G.evaluate(
            #     D_test,
            #     batch_size,
            #     BCE_loss,
            #     device
            # )

            D_losses.append(D_loss)
            G_losses.append(G_loss)
            # G_fix_losses.append(G_fix_loss)

        meanDloss = torch.mean(torch.FloatTensor(D_losses))
        meanGloss = torch.mean(torch.FloatTensor(G_losses))
        # G_fix_losses stays empty while G.evaluate(...) is commented out above,
        # so this mean evaluates to nan.
        meanGFloss = torch.mean(torch.FloatTensor(G_fix_losses))
        train_hist['D_losses'].append(meanDloss)
        train_hist['G_losses'].append(meanGloss)
        train_hist['G_fix_losses'].append(meanGFloss)
        print(
            "[{:d}/{:d}]: loss_d: {:.3f}, loss_g: {:.3f}, loss_g_fix: {:.3f}".
            format(epoch + 1, epochs, meanDloss, meanGloss, meanGFloss))
        p = PATH + '/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
        fixed_p = PATH + '/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
        z_ = torch.randn((5 * 5, 100)).to(device)
        show_result(G,
                    100,
                    fixed_z_,
                    z_, (epoch + 1),
                    save=True,
                    path=p,
                    isFix=False)
        show_result(G,
                    100,
                    fixed_z_,
                    z_, (epoch + 1),
                    save=True,
                    path=fixed_p,
                    isFix=True)

    print("Training complete. Saving.")
    save_models(D, G, PATH, train_hist, epochs)
    show_train_hist(train_hist,
                    save=True,
                    path=PATH + '/MNIST_DCGAN_train_hist.png')
    save_gif(PATH, epochs)

    return D, G
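
In the loop above, fixed_z_ is drawn once before training and reused for every saved image grid, so successive 'Fixed_results' images show the generator's progress on identical latent codes, while z_ is resampled each epoch for fresh random samples. A minimal sketch of that fixed-noise trick; the nn.Linear generator is a stand-in for illustration, not the example's Generator:

import torch
import torch.nn as nn

G = nn.Linear(100, 784)                  # stand-in generator: latent vector -> flat 28x28 image

fixed_z = torch.randn(25, 100)           # drawn once, reused at every checkpoint
for epoch in range(3):
    z = torch.randn(25, 100)             # fresh noise: different samples every epoch
    random_imgs = G(z)                   # changes with both G and z
    fixed_imgs = G(fixed_z)              # changes only with G, so epochs are directly comparable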