Code example #1
# assumed imports for this and the other DANN snippets below; params, utils,
# test, save_model, visualize, set_model_mode and the *_test_loader globals
# are module-level helpers in the original script
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def source_only(encoder, classifier, source_train_loader, target_train_loader,
                save_name):
    print("Source-only training")
    classifier_criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(list(encoder.parameters()) +
                          list(classifier.parameters()),
                          lr=0.01,
                          momentum=0.9)

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        set_model_mode('train', [encoder, classifier])

        start_steps = epoch * len(source_train_loader)
        # zip() stops at the shorter loader; scale progress p by the same
        # source loader length used for start_steps
        total_steps = params.epochs * len(source_train_loader)

        for batch_idx, (source_data, target_data) in enumerate(
                zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data
            p = float(batch_idx + start_steps) / total_steps

            # MNIST is single-channel; repeat it to 3 channels for the encoder
            source_image = torch.cat(
                (source_image, source_image, source_image), 1)
            source_image = source_image.cuda()
            source_label = source_label.cuda()

            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            source_feature = encoder(source_image)

            # Classification loss
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            class_loss.backward()
            optimizer.step()
            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tClass Loss: {:.6f}'.format(
                    batch_idx * len(source_image),
                    len(source_train_loader.dataset),
                    100. * batch_idx / len(source_train_loader),
                    class_loss.item()))

        if (epoch + 1) % 10 == 0:
            test.tester(encoder,
                        classifier,
                        None,
                        source_test_loader,
                        target_test_loader,
                        training_mode='source_only')
    save_model(encoder, classifier, None, 'source', save_name)
    visualize(encoder, 'source', save_name)
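
A minimal invocation sketch for the function above, assuming hypothetical
Extractor/Classifier modules and MNIST/MNIST-M style loaders (all names below
are placeholders, not the repo's actual API):

# hypothetical usage; any nn.Module pair with matching feature shapes works,
# and the loaders yield (image, label) batches
encoder = Extractor().cuda()        # placeholder module
classifier = Classifier().cuda()    # placeholder module
source_only(encoder, classifier, source_train_loader, target_train_loader,
            save_name='mnist_to_mnistm')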
Code example #2
# assumed imports for this and the other ACN snippets below; set_model_mode,
# kl_loss_function, the discretized-mix-logistic helpers and the model
# classes come from the same training script
import os
import time
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def forward_pass(model_dict, data, label, batch_indexes, phase, info):
    # prepare data in appropriate way
    model_dict = set_model_mode(model_dict, phase)
    # the reconstruction target is the input itself
    target = data = data.to(info['device'])
    # input values are between 0 and 1
    bs, c, h, w = data.shape
    z, u_q = model_dict['acn_model'](data)
    u_q_flat = u_q.view(bs, info['code_length'])
    if phase == 'train':
        # check that we are getting what we think we are getting from the replay
        # buffer
        assert batch_indexes.max() < model_dict['prior_model'].codes.shape[0]
        model_dict['prior_model'].update_codebook(batch_indexes,
                                                  u_q_flat.detach())

    u_p, s_p = model_dict['prior_model'](u_q_flat)
    u_p = u_p.view(bs, model_dict['acn_model'].bottleneck_channels,
                   model_dict['acn_model'].eo, model_dict['acn_model'].eo)
    s_p = s_p.view(bs, model_dict['acn_model'].bottleneck_channels,
                   model_dict['acn_model'].eo, model_dict['acn_model'].eo)
    if info['vq_decoder']:
        rec_dml, z_e_x, z_q_x, latents = model_dict['acn_model'].decode(z)
        return model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents
    else:
        rec_dml = model_dict['acn_model'].decode(z)
        return model_dict, data, target, rec_dml, u_q, u_p, s_p
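
A minimal call sketch (names taken from the snippet above): running with
phase='valid' skips the codebook update, so it is safe for evaluation.

# evaluation-style call; wrap in no_grad since nothing here needs gradients
with torch.no_grad():
    fp_out = forward_pass(model_dict, data, label, batch_indexes, 'valid', info)
model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out[:7]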
Code example #3
def save_latents(l_filepath, model_dict, data_dict, info):
    # always run in eval mode so we don't swap neighbors
    all_vq_latents = []
    with torch.no_grad():
        for phase in ['train', 'valid']:
            if not os.path.exists(l_filepath + '_%s.npz' % phase):
                model_dict = set_model_mode(model_dict, 'valid')
                data_loader = data_dict[phase]
                for idx, (data, label, batch_index) in enumerate(data_loader):
                    fp_out = forward_pass(model_dict, data, label, batch_index,
                                          phase, info)
                    if info['vq_decoder']:
                        model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
                    else:
                        model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out
                    bs, c, h, w = data.shape
                    u_q_flat = u_q.view(bs, info['code_length'])
                    n_neighbors = info['num_k']
                    neighbor_distances, neighbor_indexes = model_dict[
                        'prior_model'].kneighbors(u_q_flat,
                                                  n_neighbors=n_neighbors)
                    if not idx:
                        all_indexes = batch_index.cpu().numpy()
                        all_labels = label.cpu().numpy()
                        all_acn_uq = u_q.cpu().numpy()
                        all_neighbors = neighbor_indexes.cpu().numpy()
                        all_neighbor_distances = (
                            neighbor_distances.cpu().numpy())
                        if info['vq_decoder']:
                            all_vq_latents = latents.cpu().numpy()
                    else:
                        all_indexes = np.append(all_indexes,
                                                batch_index.cpu().numpy())
                        all_labels = np.append(all_labels, label.cpu().numpy())
                        all_acn_uq = np.vstack((all_acn_uq, u_q.cpu().numpy()))
                        all_neighbors = np.vstack(
                            (all_neighbors, neighbor_indexes.cpu().numpy()))
                        all_neighbor_distances = np.vstack(
                            (all_neighbor_distances,
                             neighbor_distances.cpu().numpy()))
                        if info['vq_decoder']:
                            all_vq_latents = np.vstack(
                                (all_vq_latents, latents.cpu().numpy()))
                # write once per phase after the full pass; saving inside the
                # batch loop would rewrite the growing arrays every iteration
                print('finished save latents', all_neighbors.shape[0])
                np.savez(l_filepath + '_' + phase,
                         index=all_indexes,
                         labels=all_labels,
                         acn_uq=all_acn_uq,
                         neighbor_train_indexes=all_neighbors,
                         neighbor_distances=all_neighbor_distances,
                         vq_latents=all_vq_latents)

    train_results = np.load(l_filepath + '_train.npz')
    valid_results = np.load(l_filepath + '_valid.npz')
    return train_results, valid_results
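
A minimal usage sketch; the first argument is a path stem, and the .npz files
are created on the first call and simply reloaded afterwards
('results/acn_latents' is a placeholder path):

train_results, valid_results = save_latents('results/acn_latents',
                                            model_dict, data_dict, info)
print(train_results['acn_uq'].shape, train_results['neighbor_distances'].shape)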
Code example #4
def tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode):
    print("Model test ...")

    encoder.cuda()
    classifier.cuda()
    set_model_mode('eval', [encoder, classifier])
    
    if training_mode == 'dann':
        discriminator.cuda()
        set_model_mode('eval', [discriminator])
        domain_correct = 0

    source_correct = 0
    target_correct = 0

    for batch_idx, (source_data, target_data) in enumerate(zip(source_test_loader, target_test_loader)):
        p = float(batch_idx) / len(source_test_loader)
        alpha = 2. / (1. + np.exp(-10 * p)) - 1

        # 1. Source input -> Source Classification
        source_image, source_label = source_data
        source_image, source_label = source_image.cuda(), source_label.cuda()
        source_image = torch.cat((source_image, source_image, source_image), 1)  # MNIST convert to 3 channel
        source_feature = encoder(source_image)
        source_output = classifier(source_feature)
        source_pred = source_output.data.max(1, keepdim=True)[1]
        source_correct += source_pred.eq(source_label.data.view_as(source_pred)).cpu().sum()

        # 2. Target input -> Target Classification
        target_image, target_label = target_data
        target_image, target_label = target_image.cuda(), target_label.cuda()
        target_feature = encoder(target_image)
        target_output = classifier(target_feature)
        target_pred = target_output.data.max(1, keepdim=True)[1]
        target_correct += target_pred.eq(target_label.data.view_as(target_pred)).cpu().sum()

        if training_mode == 'dann':
            # 3. Combined input -> Domain Classification
            combined_image = torch.cat((source_image, target_image), 0)  # 64 = (S:32 + T:32)
            domain_source_labels = torch.zeros(source_label.shape[0]).type(torch.LongTensor)
            domain_target_labels = torch.ones(target_label.shape[0]).type(torch.LongTensor)
            domain_combined_label = torch.cat((domain_source_labels, domain_target_labels), 0).cuda()
            domain_feature = encoder(combined_image)
            domain_output = discriminator(domain_feature, alpha)
            domain_pred = domain_output.data.max(1, keepdim=True)[1]
            domain_correct += domain_pred.eq(domain_combined_label.data.view_as(domain_pred)).cpu().sum()

    if training_mode == 'dann':
        print("Test results on DANN :")
        print('\nSource Accuracy: {}/{} ({:.2f}%)\n'
              'Target Accuracy: {}/{} ({:.2f}%)\n'
              'Domain Accuracy: {}/{} ({:.2f}%)\n'.format(
                  source_correct, len(source_test_loader.dataset),
                  100. * source_correct.item() / len(source_test_loader.dataset),
                  target_correct, len(target_test_loader.dataset),
                  100. * target_correct.item() / len(target_test_loader.dataset),
                  domain_correct,
                  len(source_test_loader.dataset) + len(target_test_loader.dataset),
                  100. * domain_correct.item() /
                  (len(source_test_loader.dataset) +
                   len(target_test_loader.dataset))))
    else:
        print("Test results on source_only :")
        print('\nSource Accuracy: {}/{} ({:.2f}%)\n'
              'Target Accuracy: {}/{} ({:.2f}%)\n'.format(
                  source_correct, len(source_test_loader.dataset),
                  100. * source_correct.item() / len(source_test_loader.dataset),
                  target_correct, len(target_test_loader.dataset),
                  100. * target_correct.item() / len(target_test_loader.dataset)))
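
A minimal evaluation sketch; discriminator may be None when only the label
classifier is being tested, matching the source_only call in code example #1:

tester(encoder, classifier, None, source_test_loader, target_test_loader,
       training_mode='source_only')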
Code example #5
def call_plot(model_dict, data_dict, info, sample, tsne, pca):
    from utils import tsne_plot
    from utils import pca_plot
    # always stay in eval mode so we don't swap neighbors
    model_dict = set_model_mode(model_dict, 'valid')
    srandom_state = np.random.RandomState(1234)
    with torch.no_grad():
        for phase in ['train', 'valid']:
            batch_index = srandom_state.randint(0,
                                                len(data_dict[phase].dataset),
                                                info['batch_size'])
            print(batch_index)
            data = torch.stack([
                data_dict[phase].dataset.indexed_dataset[index][0]
                for index in batch_index
            ])
            label = torch.stack([
                data_dict[phase].dataset.indexed_dataset[index][1]
                for index in batch_index
            ])
            batch_index = torch.LongTensor(batch_index)
            data = data.float()
            fp_out = forward_pass(model_dict, data, label, batch_index,
                                  'valid', info)
            if info['vq_decoder']:
                model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
            else:
                model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out

            bs, c, h, w = data.shape
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            data = data.detach().cpu().numpy()
            rec = rec_yhat.detach().cpu().numpy()
            u_q_flat = u_q.view(bs, info['code_length'])
            # choose limited number to plot
            n = min([20, bs])
            n_neighbors = info['num_k']
            if sample:
                all_neighbor_distances, all_neighbor_indexes = model_dict[
                    'prior_model'].kneighbors(u_q_flat,
                                              n_neighbors=n_neighbors)
                all_neighbor_indexes = all_neighbor_indexes.cpu().numpy()
                all_neighbor_distances = all_neighbor_distances.cpu().numpy()

                n_cols = 2 + n_neighbors
                tbatch_index = batch_index.cpu().numpy()
                np_label = label.cpu().numpy()
                for i in np.arange(0, n):
                    # plot each base image
                    plt_path = info['model_loadpath'].replace(
                        '.pt', '_batch_rec_neighbors_%s_%06d_plt.png' %
                        (phase, tbatch_index[i]))
                    neighbor_indexes = all_neighbor_indexes[i]
                    code = u_q[i].view(
                        (1, model_dict['acn_model'].bottleneck_channels,
                         model_dict['acn_model'].eo,
                         model_dict['acn_model'].eo)).cpu().numpy()
                    f, ax = plt.subplots(4, n_cols)
                    ax[0, 0].set_title('L%sI%s' %
                                       (np_label[i], tbatch_index[i]))
                    ax[0, 0].set_ylabel('true')
                    ax[0, 0].matshow(data[i, 0])
                    ax[1, 0].set_ylabel('rec')
                    ax[1, 0].matshow(rec[i, 0])
                    ax[2, 0].matshow(code[0, 0])
                    ax[3, 0].matshow(code[0, 1])

                    neighbor_data = torch.stack([
                        data_dict['train'].dataset.indexed_dataset[index][0]
                        for index in neighbor_indexes
                    ])
                    neighbor_label = torch.stack([
                        data_dict['train'].dataset.indexed_dataset[index][1]
                        for index in neighbor_indexes
                    ])
                    # stored flattened codes for the neighbor indexes
                    neighbor_codes_flat = model_dict['prior_model'].codes[
                        neighbor_indexes]
                    neighbor_codes = neighbor_codes_flat.view(
                        n_neighbors,
                        model_dict['acn_model'].bottleneck_channels,
                        model_dict['acn_model'].eo, model_dict['acn_model'].eo)
                    if info['vq_decoder']:
                        neighbor_rec_dml, _, _, _ = model_dict[
                            'acn_model'].decode(
                                neighbor_codes.to(info['device']))
                    else:
                        neighbor_rec_dml = model_dict['acn_model'].decode(
                            neighbor_codes.to(info['device']))
                    neighbor_data = neighbor_data.cpu().numpy()
                    neighbor_rec_yhat = sample_from_discretized_mix_logistic(
                        neighbor_rec_dml,
                        info['nr_logistic_mix'],
                        only_mean=info['sample_mean'],
                        sampling_temperature=info['sampling_temperature'])
                    neighbor_rec_yhat = neighbor_rec_yhat.cpu().numpy()
                    for ni in range(n_neighbors):
                        nindex = all_neighbor_indexes[i, ni].item()
                        nlabel = neighbor_label[ni].cpu().numpy()
                        ncode = neighbor_codes[ni].cpu().numpy()
                        ax[0, ni + 2].set_title('L%sI%s' % (nlabel, nindex))
                        ax[0, ni + 2].matshow(neighbor_data[ni, 0])
                        ax[1, ni + 2].matshow(neighbor_rec_yhat[ni, 0])
                        ax[2, ni + 2].matshow(ncode[0])
                        ax[3, ni + 2].matshow(ncode[1])
                    ax[2, 0].set_ylabel('lc0')
                    ax[3, 0].set_ylabel('lc1')
                    [ax[xx, 0].set_xticks([]) for xx in range(4)]
                    [ax[xx, 0].set_yticks([]) for xx in range(4)]
                    for xx in range(4):
                        [ax[xx, col].axis('off') for col in range(1, n_cols)]
                    plt.subplots_adjust(wspace=0, hspace=0)
                    plt.tight_layout()
                    print('plotting', plt_path)
                    plt.savefig(plt_path)
                    plt.close()
            X = u_q_flat.cpu().numpy()
            # color points by label; a clustering (e.g. KMeans on X) or the
            # batch index would work here as well
            color = label.cpu().numpy()
            if tsne:
                param_name = '_tsne_%s_P%s.html' % (phase, info['perplexity'])
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    tsne_plot(X=X,
                              images=data[:, 0],
                              color=color,
                              perplexity=info['perplexity'],
                              html_out_path=html_path,
                              serve=False)
            if pca:
                param_name = '_pca_%s.html' % (phase)
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    pca_plot(X=X,
                             images=data[:, 0],
                             color=color,
                             html_out_path=html_path,
                             serve=False)
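
A minimal invocation sketch; the three flags toggle the neighbor plots and
the t-SNE/PCA embeddings, and info['model_loadpath'] is assumed to end in
'.pt' since every output path is derived from it:

call_plot(model_dict, data_dict, info, sample=True, tsne=True, pca=False)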
Code example #6
def run(train_cnt, model_dict, data_dict, phase, info):
    st = time.time()
    loss_dict = {
        'running': 0,
        'kl': 0,
        'rec_%s' % info['rec_loss_type']: 0,
        'loss': 0,
    }
    if info['vq_decoder']:
        loss_dict['vq'] = 0
        loss_dict['commit'] = 0

    # data_dict[phase] is a DataLoader, so len() already gives the batch count
    data_loader = data_dict[phase]
    num_batches = len(data_loader)
    print(phase, 'num batches', num_batches)
    set_model_mode(model_dict, phase)
    torch.set_grad_enabled(phase == 'train')
    batch_cnt = 0

    for idx, (data, label, batch_index) in enumerate(data_loader):
        for key in model_dict.keys():
            model_dict[key].zero_grad()
        fp_out = forward_pass(model_dict, data, label, batch_index, phase,
                              info)
        if info['vq_decoder']:
            model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
        else:
            model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out
        bs, c, h, w = data.shape
        # reallocate the zeros tensor only when the batch size changes
        # (e.g. the final, smaller batch)
        if batch_cnt == 0 or bs != log_ones.shape[0]:
            log_ones = torch.zeros(bs, info['code_length']).to(info['device'])
        kl = kl_loss_function(u_q.view(bs, info['code_length']),
                              log_ones,
                              u_p.view(bs, info['code_length']),
                              s_p.view(bs, info['code_length']),
                              reduction=info['reduction'])
        rec_loss = discretized_mix_logistic_loss(
            rec_dml,
            target,
            nr_mix=info['nr_logistic_mix'],
            reduction=info['reduction'])

        if info['vq_decoder']:
            vq_loss = F.mse_loss(z_q_x,
                                 z_e_x.detach(),
                                 reduction=info['reduction'])
            commit_loss = F.mse_loss(z_e_x,
                                     z_q_x.detach(),
                                     reduction=info['reduction'])
            commit_loss *= info['vq_commitment_beta']
            loss_dict['vq'] += vq_loss.detach().cpu().item()
            loss_dict['commit'] += commit_loss.detach().cpu().item()
            loss = kl + rec_loss + commit_loss + vq_loss
        else:
            loss = kl + rec_loss

        # accumulate each term once per batch
        loss_dict['running'] += bs
        loss_dict['rec_%s' %
                  info['rec_loss_type']] += rec_loss.detach().cpu().item()
        loss_dict['loss'] += loss.detach().cpu().item()
        loss_dict['kl'] += kl.detach().cpu().item()
        if phase == 'train':
            model_dict = clip_parameters(model_dict)
            loss.backward()
            model_dict['opt'].step()
            train_cnt += bs
        if batch_cnt == num_batches - 1:
            # store example near end for plotting
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml,
                info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            example = {
                'target': data.detach().cpu().numpy(),
                'rec': rec_yhat.detach().cpu().numpy(),
            }
        if not batch_cnt % 100:
            print(train_cnt, batch_cnt, account_losses(loss_dict))
            print(phase, 'cuda', torch.cuda.memory_allocated(device=None))
        batch_cnt += 1

    loss_avg = account_losses(loss_dict)
    torch.cuda.empty_cache()
    print("finished %s after %s secs at cnt %s" % (
        phase,
        time.time() - st,
        train_cnt,
    ))
    del data
    del target
    return loss_avg, example
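
A minimal driver sketch around run() (names assumed; 'num_epochs' is a
hypothetical key). Note that run() advances train_cnt only inside its own
scope, so the caller's counter stays at its initial value between epochs:

train_cnt = 0
for epoch in range(info['num_epochs']):
    train_loss_avg, train_example = run(train_cnt, model_dict, data_dict,
                                        'train', info)
    valid_loss_avg, valid_example = run(train_cnt, model_dict, data_dict,
                                        'valid', info)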
Code example #7
def dann(encoder, classifier, discriminator, source_train_loader,
         target_train_loader, save_name):
    print("DANN training")

    classifier_criterion = nn.CrossEntropyLoss().cuda()
    discriminator_criterion = nn.CrossEntropyLoss().cuda()

    optimizer = optim.SGD(list(encoder.parameters()) +
                          list(classifier.parameters()) +
                          list(discriminator.parameters()),
                          lr=0.01,
                          momentum=0.9)

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        set_model_mode('train', [encoder, classifier, discriminator])

        start_steps = epoch * len(source_train_loader)
        # zip() stops at the shorter loader; scale progress p by the same
        # source loader length used for start_steps
        total_steps = params.epochs * len(source_train_loader)

        for batch_idx, (source_data, target_data) in enumerate(
                zip(source_train_loader, target_train_loader)):

            source_image, source_label = source_data
            target_image, target_label = target_data

            p = float(batch_idx + start_steps) / total_steps
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # MNIST source images are single-channel; repeat to 3 channels
            source_image = torch.cat(
                (source_image, source_image, source_image), 1)

            source_image = source_image.cuda()
            source_label = source_label.cuda()
            target_image = target_image.cuda()
            target_label = target_label.cuda()
            combined_image = torch.cat((source_image, target_image), 0)

            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            combined_feature = encoder(combined_image)
            # source features are re-encoded on their own for the label head
            source_feature = encoder(source_image)

            # 1. Classification loss
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            # 2. Domain loss
            domain_pred = discriminator(combined_feature, alpha)

            domain_source_labels = torch.zeros(source_label.shape[0]).type(
                torch.LongTensor)
            domain_target_labels = torch.ones(target_label.shape[0]).type(
                torch.LongTensor)
            domain_combined_label = torch.cat(
                (domain_source_labels, domain_target_labels), 0).cuda()
            domain_loss = discriminator_criterion(domain_pred,
                                                  domain_combined_label)

            total_loss = class_loss + domain_loss
            total_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print(
                    '[{}/{} ({:.0f}%)]\tLoss: {:.6f}\tClass Loss: {:.6f}\tDomain Loss: {:.6f}'
                    .format(batch_idx * len(target_image),
                            len(target_train_loader.dataset),
                            100. * batch_idx / len(target_train_loader),
                            total_loss.item(), class_loss.item(),
                            domain_loss.item()))

        if (epoch + 1) % 10 == 0:
            test.tester(encoder,
                        classifier,
                        discriminator,
                        source_test_loader,
                        target_test_loader,
                        training_mode='dann')

    save_model(encoder, classifier, discriminator, 'dann', save_name)
    visualize(encoder, 'dann', save_name)
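
A minimal training entry point mirroring the source_only sketch after code
example #1 (module names are placeholders):

encoder = Extractor().cuda()              # placeholder module
classifier = Classifier().cuda()          # placeholder module
discriminator = Discriminator().cuda()    # placeholder module
dann(encoder, classifier, discriminator, source_train_loader,
     target_train_loader, save_name='mnist_to_mnistm_dann')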