def source_only(encoder, classifier, source_train_loader, target_train_loader, save_name):
    """Train ``encoder`` + ``classifier`` on labelled source data only.

    Target batches are drawn alongside the source batches but never used for
    the loss; because the two loaders are zipped, each epoch runs for the
    length of the shorter loader.

    Every 10 epochs the model is evaluated via ``test.tester`` on the
    module-level ``source_test_loader`` / ``target_test_loader``; at the end the
    weights are saved with ``save_model`` and ``visualize`` is called.

    Args:
        encoder: feature extractor module.
        classifier: label predictor applied to encoder features.
        source_train_loader: yields ``(image, label)`` source batches; images
            are replicated to 3 channels (single-channel MNIST input).
        target_train_loader: zipped with the source loader; only paces the loop.
        save_name: tag forwarded to ``save_model`` and ``visualize``.
    """
    print("Source-only training")

    classifier_criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(
        list(encoder.parameters()) + list(classifier.parameters()),
        lr=0.01,
        momentum=0.9)

    # zip() stops at the shorter loader, so that is the real number of
    # iterations per epoch. Using the same value for start_steps and
    # total_steps keeps the progress fraction p inside [0, 1].
    steps_per_epoch = min(len(source_train_loader), len(target_train_loader))
    total_steps = params.epochs * steps_per_epoch

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        set_model_mode('train', [encoder, classifier])
        start_steps = epoch * steps_per_epoch

        for batch_idx, (source_data, _target_data) in enumerate(
                zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data
            p = float(batch_idx + start_steps) / total_steps

            # MNIST is single-channel: replicate it to 3 channels for the encoder.
            source_image = torch.cat((source_image, source_image, source_image), 1)
            source_image, source_label = source_image.cuda(), source_label.cuda()

            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            source_feature = encoder(source_image)

            # Classification loss
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            class_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tClass Loss: {:.6f}'.format(
                    batch_idx * len(source_image),
                    len(source_train_loader.dataset),
                    100. * batch_idx / len(source_train_loader),
                    class_loss.item()))

        if (epoch + 1) % 10 == 0:
            test.tester(encoder, classifier, None, source_test_loader,
                        target_test_loader, training_mode='source_only')

    save_model(encoder, classifier, None, 'source', save_name)
    visualize(encoder, 'source', save_name)
def forward_pass(model_dict, data, label, batch_indexes, phase, info):
    """Encode ``data`` with the ACN model, query the prior and decode.

    The models are put into ``phase`` mode via ``set_model_mode``. During the
    'train' phase the prior's codebook rows at ``batch_indexes`` are overwritten
    with the freshly computed (detached) posterior codes.

    Args:
        model_dict: dict holding at least 'acn_model' and 'prior_model'.
        data: input batch, unpacked as ``(bs, c, h, w)``; per the original
            comment values lie between 0 and 1.
        label: unused here; kept for a uniform call signature.
        batch_indexes: dataset indexes of the batch (codebook rows to update).
        phase: 'train' triggers the codebook update; any other value does not.
        info: config dict; reads 'device', 'code_length', 'vq_decoder'.

    Returns:
        ``(model_dict, data, target, rec_dml, u_q, u_p, s_p)``, extended with
        ``(z_e_x, z_q_x, latents)`` when ``info['vq_decoder']`` is truthy.
        ``target`` is the same tensor as ``data`` (reconstruction target).
    """
    model_dict = set_model_mode(model_dict, phase)

    data = data.to(info['device'])
    target = data
    bs, _, _, _ = data.shape

    acn_model = model_dict['acn_model']
    z, u_q = acn_model(data)
    u_q_flat = u_q.view(bs, info['code_length'])

    prior_model = model_dict['prior_model']
    if phase == 'train':
        # Sanity check: every replay-buffer index must address an existing codebook row.
        assert batch_indexes.max() < prior_model.codes.shape[0]
        prior_model.update_codebook(batch_indexes, u_q_flat.detach())

    u_p, s_p = prior_model(u_q_flat)
    spatial_code_shape = (bs, acn_model.bottleneck_channels, acn_model.eo, acn_model.eo)
    u_p = u_p.view(spatial_code_shape)
    s_p = s_p.view(spatial_code_shape)

    if not info['vq_decoder']:
        rec_dml = acn_model.decode(z)
        return model_dict, data, target, rec_dml, u_q, u_p, s_p

    rec_dml, z_e_x, z_q_x, latents = acn_model.decode(z)
    return model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents
def _stack_rows(chunks):
    """Stack per-batch arrays along axis 0; a single chunk is returned unchanged."""
    if len(chunks) == 1:
        return chunks[0]
    return np.vstack(chunks)


def _concat_flat(chunks):
    """Flatten and concatenate per-batch arrays (same result as repeated ``np.append``)."""
    if len(chunks) == 1:
        return chunks[0]
    return np.concatenate([np.ravel(chunk) for chunk in chunks])


def save_latents(l_filepath, model_dict, data_dict, info):
    """Encode the train/valid splits and cache codes plus nearest neighbours.

    For each phase in ('train', 'valid') whose ``<l_filepath>_<phase>.npz`` does
    not exist yet, every batch of ``data_dict[phase]`` is run through
    ``forward_pass`` in 'valid' (eval) mode, the ``info['num_k']`` nearest
    codebook entries of each posterior code are looked up with
    ``prior_model.kneighbors``, and everything is written with ``np.savez``.

    Args:
        l_filepath: path prefix; files are ``<l_filepath>_train.npz`` and
            ``<l_filepath>_valid.npz``.
        model_dict: dict holding at least 'acn_model' and 'prior_model'.
        data_dict: maps 'train'/'valid' to loaders yielding
            ``(data, label, batch_index)``.
        info: config dict; reads 'code_length', 'num_k', 'vq_decoder'
            (and whatever ``forward_pass`` reads).

    Returns:
        ``(train_results, valid_results)``: the ``np.load`` results of the two files.

    Raises:
        ValueError: if a phase that still needs computing has no batches.
    """
    with torch.no_grad():
        for phase in ['train', 'valid']:
            if os.path.exists(l_filepath + '_%s.npz' % phase):
                continue

            # Always evaluate in eval mode, so the prior does not swap neighbours.
            model_dict = set_model_mode(model_dict, 'valid')
            data_loader = data_dict[phase]

            indexes, labels, acn_uq = [], [], []
            neighbors, neighbor_distances, vq_latents = [], [], []
            for data, label, batch_index in data_loader:
                # Pass 'valid' even for the train split: passing 'train' would
                # switch the models back into train mode and make forward_pass
                # overwrite codebook rows while we are only reading latents.
                fp_out = forward_pass(model_dict, data, label, batch_index, 'valid', info)
                if info['vq_decoder']:
                    model_dict, data, _, _, u_q, _, _, _, _, latents = fp_out
                    vq_latents.append(latents.cpu().numpy())
                else:
                    model_dict, data, _, _, u_q, _, _ = fp_out

                bs = data.shape[0]
                u_q_flat = u_q.view(bs, info['code_length'])
                batch_distances, batch_neighbors = model_dict['prior_model'].kneighbors(
                    u_q_flat, n_neighbors=info['num_k'])

                indexes.append(batch_index.cpu().numpy())
                labels.append(label.cpu().numpy())
                acn_uq.append(u_q.cpu().numpy())
                neighbors.append(batch_neighbors.cpu().numpy())
                neighbor_distances.append(batch_distances.cpu().numpy())

            if not indexes:
                raise ValueError(
                    'data_dict[%r] produced no batches; cannot save latents' % phase)

            # Concatenate once at the end instead of re-allocating every batch.
            all_neighbors = _stack_rows(neighbors)
            print('finished save latents', all_neighbors.shape[0])
            np.savez(l_filepath + '_' + phase,
                     index=_concat_flat(indexes),
                     labels=_concat_flat(labels),
                     acn_uq=_stack_rows(acn_uq),
                     neighbor_train_indexes=all_neighbors,
                     neighbor_distances=_stack_rows(neighbor_distances),
                     vq_latents=_stack_rows(vq_latents) if vq_latents else [])

    train_results = np.load(l_filepath + '_train.npz')
    valid_results = np.load(l_filepath + '_valid.npz')
    return train_results, valid_results
def tester(encoder, classifier, discriminator, source_test_loader, target_test_loader, training_mode):
    """Report classification (and, for DANN, domain) accuracy on the test sets.

    The source and target test loaders are iterated together with ``zip``, so
    evaluation stops at the shorter one. Accuracies are therefore computed over
    the number of samples actually evaluated, not ``len(loader.dataset)``.
    Everything runs under ``torch.no_grad()``.

    Args:
        encoder: feature extractor; moved to CUDA and set to eval mode.
        classifier: label predictor; moved to CUDA and set to eval mode.
        discriminator: domain classifier called as ``discriminator(feature, alpha)``;
            only used when ``training_mode == 'dann'`` (may be ``None`` otherwise).
        source_test_loader: yields ``(image, label)``; images are replicated to
            3 channels before encoding (single-channel MNIST input).
        target_test_loader: yields ``(image, label)`` target batches.
        training_mode: 'dann' also evaluates domain accuracy; any other value
            prints the source-only report.
    """
    print("Model test ...")

    encoder.cuda()
    classifier.cuda()
    set_model_mode('eval', [encoder, classifier])

    is_dann = training_mode == 'dann'
    if is_dann:
        discriminator.cuda()
        set_model_mode('eval', [discriminator])

    source_correct = target_correct = domain_correct = 0
    source_seen = target_seen = domain_seen = 0

    def percent(correct, total):
        # Avoid ZeroDivisionError when a loader yields no batches.
        return 100. * correct / total if total else 0.0

    with torch.no_grad():
        for batch_idx, (source_data, target_data) in enumerate(
                zip(source_test_loader, target_test_loader)):
            p = float(batch_idx) / len(source_test_loader)
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # 1. Source input -> Source Classification
            source_image, source_label = source_data
            source_image, source_label = source_image.cuda(), source_label.cuda()
            source_image = torch.cat((source_image, source_image, source_image), 1)  # MNIST convert to 3 channel
            source_output = classifier(encoder(source_image))
            source_pred = source_output.max(1, keepdim=True)[1]
            source_correct += source_pred.eq(source_label.view_as(source_pred)).sum().item()
            source_seen += source_label.shape[0]

            # 2. Target input -> Target Classification
            target_image, target_label = target_data
            target_image, target_label = target_image.cuda(), target_label.cuda()
            target_output = classifier(encoder(target_image))
            target_pred = target_output.max(1, keepdim=True)[1]
            target_correct += target_pred.eq(target_label.view_as(target_pred)).sum().item()
            target_seen += target_label.shape[0]

            if is_dann:
                # 3. Combined input -> Domain Classification (label 0 = source, 1 = target)
                combined_image = torch.cat((source_image, target_image), 0)
                domain_source_labels = torch.zeros(source_label.shape[0]).type(torch.LongTensor)
                domain_target_labels = torch.ones(target_label.shape[0]).type(torch.LongTensor)
                domain_combined_label = torch.cat(
                    (domain_source_labels, domain_target_labels), 0).cuda()

                domain_output = discriminator(encoder(combined_image), alpha)
                domain_pred = domain_output.max(1, keepdim=True)[1]
                domain_correct += domain_pred.eq(
                    domain_combined_label.view_as(domain_pred)).sum().item()
                domain_seen += domain_combined_label.shape[0]

    if is_dann:
        print("Test Results on DANN :")
        print('\nSource Accuracy: {}/{} ({:.2f}%)\n'
              'Target Accuracy: {}/{} ({:.2f}%)\n'
              'Domain Accuracy: {}/{} ({:.2f}%)\n'.format(
                  source_correct, source_seen, percent(source_correct, source_seen),
                  target_correct, target_seen, percent(target_correct, target_seen),
                  domain_correct, domain_seen, percent(domain_correct, domain_seen)))
    else:
        print("Test results on source_only :")
        print('\nSource Accuracy: {}/{} ({:.2f}%)\n'
              'Target Accuracy: {}/{} ({:.2f}%)\n'.format(
                  source_correct, source_seen, percent(source_correct, source_seen),
                  target_correct, target_seen, percent(target_correct, target_seen)))
def _plot_sample_neighbors(model_dict, data_dict, info, phase, data, rec, u_q,
                           u_q_flat, label, batch_index, n):
    """Save one PNG per example (first ``n``) next to its nearest training codes.

    Each figure has 4 rows (input, reconstruction, code channel 0, code
    channel 1). Column 0 is the example, column 1 is left blank, and columns
    2.. show the ``info['num_k']`` neighbours returned by
    ``prior_model.kneighbors``, decoded with the ACN decoder.
    """
    acn_model = model_dict['acn_model']
    prior_model = model_dict['prior_model']
    n_neighbors = info['num_k']
    code_shape = (acn_model.bottleneck_channels, acn_model.eo, acn_model.eo)

    _, all_neighbor_indexes = prior_model.kneighbors(u_q_flat, n_neighbors=n_neighbors)
    all_neighbor_indexes = all_neighbor_indexes.cpu().numpy()
    n_cols = 2 + n_neighbors
    tbatch_index = batch_index.cpu().numpy()
    np_label = label.cpu().numpy()
    # Neighbours always come from the training split's codebook.
    train_items = data_dict['train'].dataset.indexed_dataset

    for i in range(n):
        plt_path = info['model_loadpath'].replace(
            '.pt', '_batch_rec_neighbors_%s_%06d_plt.png' % (phase, tbatch_index[i]))
        neighbor_indexes = all_neighbor_indexes[i]
        code = u_q[i].view((1,) + code_shape).cpu().numpy()

        _, ax = plt.subplots(4, n_cols)
        ax[0, 0].set_title('L%sI%s' % (np_label[i], tbatch_index[i]))
        ax[0, 0].set_ylabel('true')
        ax[0, 0].matshow(data[i, 0])
        ax[1, 0].set_ylabel('rec')
        ax[1, 0].matshow(rec[i, 0])
        ax[2, 0].matshow(code[0, 0])
        ax[3, 0].matshow(code[0, 1])

        neighbor_data = torch.stack([train_items[index][0] for index in neighbor_indexes])
        neighbor_label = torch.stack([train_items[index][1] for index in neighbor_indexes])
        neighbor_codes = prior_model.codes[neighbor_indexes].view((n_neighbors,) + code_shape)
        if info['vq_decoder']:
            neighbor_rec_dml, _, _, _ = acn_model.decode(neighbor_codes.to(info['device']))
        else:
            neighbor_rec_dml = acn_model.decode(neighbor_codes.to(info['device']))
        neighbor_data = neighbor_data.cpu().numpy()
        neighbor_rec_yhat = sample_from_discretized_mix_logistic(
            neighbor_rec_dml, info['nr_logistic_mix'],
            only_mean=info['sample_mean'],
            sampling_temperature=info['sampling_temperature']).cpu().numpy()

        for ni in range(n_neighbors):
            col = ni + 2
            nindex = all_neighbor_indexes[i, ni].item()
            nlabel = neighbor_label[ni].cpu().numpy()
            ncode = neighbor_codes[ni].cpu().numpy()
            ax[0, col].set_title('L%sI%s' % (nlabel, nindex))
            ax[0, col].matshow(neighbor_data[ni, 0])
            ax[1, col].matshow(neighbor_rec_yhat[ni, 0])
            ax[2, col].matshow(ncode[0])
            ax[3, col].matshow(ncode[1])

        ax[2, 0].set_ylabel('lc0')
        ax[3, 0].set_ylabel('lc1')
        for row in range(4):
            ax[row, 0].set_xticks([])
            ax[row, 0].set_yticks([])
            for col in range(1, n_cols):
                ax[row, col].axis('off')

        plt.subplots_adjust(wspace=0, hspace=0)
        plt.tight_layout()
        print('plotting', plt_path)
        plt.savefig(plt_path)
        plt.close()


def call_plot(model_dict, data_dict, info, sample, tsne, pca):
    """Visualise codes for a fixed random batch of each split.

    For 'train' and 'valid', ``info['batch_size']`` examples are drawn from
    ``data_dict[phase].dataset.indexed_dataset`` with a fixed seed (1234) and
    encoded with ``forward_pass`` in 'valid' mode.

    Args:
        model_dict: dict holding at least 'acn_model' and 'prior_model'.
        data_dict: maps phase name to a loader whose ``dataset`` exposes
            ``indexed_dataset``.
        info: config dict (batch_size, code_length, num_k, model_loadpath, ...).
        sample: if true, save neighbour-comparison PNGs for up to 20 examples.
        tsne: if true, write a t-SNE HTML plot unless the file already exists.
        pca: if true, write a PCA HTML plot unless the file already exists.
    """
    from utils import tsne_plot
    from utils import pca_plot

    # Always be in eval mode, so the prior does not swap neighbours.
    model_dict = set_model_mode(model_dict, 'valid')
    srandom_state = np.random.RandomState(1234)
    with torch.no_grad():
        for phase in ['train', 'valid']:
            items = data_dict[phase].dataset.indexed_dataset
            batch_index = srandom_state.randint(
                0, len(data_dict[phase].dataset), info['batch_size'])
            print(batch_index)
            data = torch.stack([items[index][0] for index in batch_index])
            label = torch.stack([items[index][1] for index in batch_index])
            batch_index = torch.LongTensor(batch_index)
            data = torch.FloatTensor(data)

            fp_out = forward_pass(model_dict, data, label, batch_index, 'valid', info)
            if info['vq_decoder']:
                model_dict, data, _, rec_dml, u_q, _, _, _, _, _ = fp_out
            else:
                model_dict, data, _, rec_dml, u_q, _, _ = fp_out

            bs = data.shape[0]
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml, info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            data = data.detach().cpu().numpy()
            rec = rec_yhat.detach().cpu().numpy()
            u_q_flat = u_q.view(bs, info['code_length'])

            if sample:
                # Limit the number of figures written per phase.
                n = min(20, bs)
                _plot_sample_neighbors(model_dict, data_dict, info, phase, data, rec,
                                       u_q, u_q_flat, label, batch_index, n)

            X = u_q_flat.cpu().numpy()
            # Colour embedding points by their class label.
            color = label.cpu().numpy()
            if tsne:
                param_name = '_tsne_%s_P%s.html' % (phase, info['perplexity'])
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    tsne_plot(X=X, images=data[:, 0], color=color,
                              perplexity=info['perplexity'],
                              html_out_path=html_path, serve=False)
            if pca:
                param_name = '_pca_%s.html' % (phase)
                html_path = info['model_loadpath'].replace('.pt', param_name)
                if not os.path.exists(html_path):
                    pca_plot(X=X, images=data[:, 0], color=color,
                             html_out_path=html_path, serve=False)
def run(train_cnt, model_dict, data_dict, phase, info):
    """Run one pass of ACN training or evaluation over ``data_dict[phase]``.

    Args:
        train_cnt: running number of training examples seen; incremented by
            the batch size for every batch when ``phase == 'train'``.
        model_dict: dict of models (each ``zero_grad``-able) plus the optimizer
            under 'opt'.
        data_dict: maps phase name to a loader yielding ``(data, label, batch_index)``.
        phase: 'train' enables gradients, clipping and optimizer steps; any
            other value only evaluates.
        info: config dict (rec_loss_type, vq_decoder, code_length, reduction, ...).

    Returns:
        ``(loss_avg, example)``. ``loss_avg`` is ``account_losses(loss_dict)``.
        ``example`` holds numpy 'target'/'rec' arrays from the last batch, or
        ``None`` if the loader yielded nothing.

    Note:
        Calls ``torch.set_grad_enabled`` globally, so the grad mode persists
        after this function returns.
    """
    st = time.time()
    rec_key = 'rec_%s' % info['rec_loss_type']
    loss_dict = {'running': 0, 'kl': 0, rec_key: 0, 'loss': 0}
    if info['vq_decoder']:
        loss_dict['vq'] = 0
        loss_dict['commit'] = 0

    data_loader = data_dict[phase]
    num_batches = len(data_loader)
    print(phase, 'num batches', num_batches)

    set_model_mode(model_dict, phase)
    torch.set_grad_enabled(phase == 'train')

    code_length = info['code_length']
    log_ones = None
    example = None
    for batch_cnt, (data, label, batch_index) in enumerate(data_loader):
        for component in model_dict.values():
            component.zero_grad()

        fp_out = forward_pass(model_dict, data, label, batch_index, phase, info)
        if info['vq_decoder']:
            model_dict, data, target, rec_dml, u_q, u_p, s_p, z_e_x, z_q_x, latents = fp_out
        else:
            model_dict, data, target, rec_dml, u_q, u_p, s_p = fp_out
        bs = data.shape[0]

        # Zeros passed as the posterior's second statistic (log(1) == 0);
        # rebuilt only when the batch size changes (e.g. a short last batch).
        if log_ones is None or log_ones.shape[0] != bs:
            log_ones = torch.zeros(bs, code_length).to(info['device'])

        kl = kl_loss_function(u_q.view(bs, code_length), log_ones,
                              u_p.view(bs, code_length),
                              s_p.view(bs, code_length),
                              reduction=info['reduction'])
        rec_loss = discretized_mix_logistic_loss(
            rec_dml, target, nr_mix=info['nr_logistic_mix'],
            reduction=info['reduction'])

        if info['vq_decoder']:
            vq_loss = F.mse_loss(z_q_x, z_e_x.detach(), reduction=info['reduction'])
            commit_loss = F.mse_loss(z_e_x, z_q_x.detach(), reduction=info['reduction'])
            commit_loss *= info['vq_commitment_beta']
            loss_dict['vq'] += vq_loss.detach().cpu().item()
            loss_dict['commit'] += commit_loss.detach().cpu().item()
            loss = kl + rec_loss + commit_loss + vq_loss
        else:
            loss = kl + rec_loss

        # Accumulate every term exactly once per batch (previously duplicated).
        loss_dict['running'] += bs
        loss_dict[rec_key] += rec_loss.detach().cpu().item()
        loss_dict['loss'] += loss.detach().cpu().item()
        loss_dict['kl'] += kl.detach().cpu().item()

        if phase == 'train':
            loss.backward()
            # Clip after backward so any gradient clipping sees this batch's
            # gradients. Weights are unchanged by backward, so clipping weights
            # here is equivalent to clipping them before it.
            model_dict = clip_parameters(model_dict)
            model_dict['opt'].step()
            train_cnt += bs

        if batch_cnt == num_batches - 1:
            # Store an example near the end for plotting.
            rec_yhat = sample_from_discretized_mix_logistic(
                rec_dml, info['nr_logistic_mix'],
                only_mean=info['sample_mean'],
                sampling_temperature=info['sampling_temperature'])
            example = {
                'target': data.detach().cpu().numpy(),
                'rec': rec_yhat.detach().cpu().numpy(),
            }

        if not batch_cnt % 100:
            print(train_cnt, batch_cnt, account_losses(loss_dict))
            print(phase, 'cuda', torch.cuda.memory_allocated(device=None))

    loss_avg = account_losses(loss_dict)
    torch.cuda.empty_cache()
    print("finished %s after %s secs at cnt %s" % (
        phase,
        time.time() - st,
        train_cnt,
    ))
    return loss_avg, example
def dann(encoder, classifier, discriminator, source_train_loader, target_train_loader, save_name):
    """Train with the DANN objective: source classification + domain confusion.

    Each step encodes the concatenated source+target batch for the
    discriminator (called with the schedule-dependent ``alpha``) and the source
    batch alone for the classifier. The summed loss is optimized with SGD,
    whose schedule is driven by the training progress fraction ``p``.

    Every 10 epochs ``test.tester`` is run on the module-level
    ``source_test_loader`` / ``target_test_loader``. At the end the models are
    saved and the encoder is visualised.

    Args:
        encoder: feature extractor module.
        classifier: label predictor applied to source features.
        discriminator: domain classifier called as ``discriminator(feature, alpha)``.
        source_train_loader: yields ``(image, label)``; images are replicated to
            3 channels (single-channel MNIST input).
        target_train_loader: yields ``(image, label)``; labels are only used
            for their batch size.
        save_name: tag forwarded to ``save_model`` and ``visualize``.
    """
    print("DANN training")

    classifier_criterion = nn.CrossEntropyLoss().cuda()
    discriminator_criterion = nn.CrossEntropyLoss().cuda()
    optimizer = optim.SGD(
        list(encoder.parameters()) + list(classifier.parameters())
        + list(discriminator.parameters()),
        lr=0.01,
        momentum=0.9)

    # zip() stops at the shorter loader; use that length for both the epoch
    # offset and the total so the progress fraction p stays within [0, 1].
    steps_per_epoch = min(len(source_train_loader), len(target_train_loader))
    total_steps = params.epochs * steps_per_epoch

    for epoch in range(params.epochs):
        print('Epoch : {}'.format(epoch))
        set_model_mode('train', [encoder, classifier, discriminator])
        start_steps = epoch * steps_per_epoch

        for batch_idx, (source_data, target_data) in enumerate(
                zip(source_train_loader, target_train_loader)):
            source_image, source_label = source_data
            target_image, target_label = target_data

            p = float(batch_idx + start_steps) / total_steps
            # Ramps from 0 towards 1 as p goes from 0 to 1.
            alpha = 2. / (1. + np.exp(-10 * p)) - 1

            # MNIST is single-channel: replicate it to 3 channels for the encoder.
            source_image = torch.cat((source_image, source_image, source_image), 1)
            source_image, source_label = source_image.cuda(), source_label.cuda()
            target_image, target_label = target_image.cuda(), target_label.cuda()
            combined_image = torch.cat((source_image, target_image), 0)

            optimizer = utils.optimizer_scheduler(optimizer=optimizer, p=p)
            optimizer.zero_grad()

            combined_feature = encoder(combined_image)
            source_feature = encoder(source_image)

            # 1. Classification loss
            class_pred = classifier(source_feature)
            class_loss = classifier_criterion(class_pred, source_label)

            # 2. Domain loss (label 0 = source, 1 = target)
            domain_pred = discriminator(combined_feature, alpha)
            domain_source_labels = torch.zeros(source_label.shape[0]).type(torch.LongTensor)
            domain_target_labels = torch.ones(target_label.shape[0]).type(torch.LongTensor)
            domain_combined_label = torch.cat(
                (domain_source_labels, domain_target_labels), 0).cuda()
            domain_loss = discriminator_criterion(domain_pred, domain_combined_label)

            total_loss = class_loss + domain_loss
            total_loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 50 == 0:
                print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}\tClass Loss: {:.6f}\tDomain Loss: {:.6f}'.format(
                    batch_idx * len(target_image),
                    len(target_train_loader.dataset),
                    100. * batch_idx / len(target_train_loader),
                    total_loss.item(),
                    class_loss.item(),
                    domain_loss.item()))

        if (epoch + 1) % 10 == 0:
            test.tester(encoder, classifier, discriminator, source_test_loader,
                        target_test_loader, training_mode='dann')

    # NOTE(review): the DANN model is saved/visualised under the 'source' tag,
    # the same one source_only uses; confirm this is intended and does not
    # overwrite the source-only checkpoint.
    save_model(encoder, classifier, discriminator, 'source', save_name)
    visualize(encoder, 'source', save_name)