def __init__(self, pnet_type='vgg', pnet_rand=False, use_gpu=True):
    super(PNet, self).__init__()
    self.use_gpu = use_gpu
    self.pnet_type = pnet_type
    self.pnet_rand = pnet_rand
    # Channel-wise shift/scale that map [-1, 1] inputs to the statistics
    # expected by the pretrained backbone.
    self.shift = to_var(
        torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1))
    self.scale = to_var(torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1))
    if self.pnet_type in ['vgg', 'vgg16']:
        self.net = pn.vgg16(pretrained=not self.pnet_rand,
                            requires_grad=False)
    elif self.pnet_type == 'alex':
        self.net = pn.alexnet(pretrained=not self.pnet_rand,
                              requires_grad=False)
    elif self.pnet_type[:-2] == 'resnet':
        self.net = pn.resnet(pretrained=not self.pnet_rand,
                             requires_grad=False,
                             num=int(self.pnet_type[-2:]))
    elif self.pnet_type == 'squeeze':
        self.net = pn.squeezenet(pretrained=not self.pnet_rand,
                                 requires_grad=False)
    self.L = self.net.N_slices
    if use_gpu:
        self.net.cuda()
        self.shift = self.shift.cuda()
        self.scale = self.scale.cuda()
def LPIPS(self):
    from misc.utils import compute_lpips
    data_loader = self.data_loader
    n_images = 100
    pair_styles = 20
    model = None
    DISTANCE = {0: [], 1: []}
    self.G.eval()
    for i, (real_x, org_c, files) in tqdm(
            enumerate(data_loader),
            desc='Calculating LPIPS',
            total=n_images):
        for _real_x, _org_c in zip(real_x, org_c):
            _real_x = _real_x.unsqueeze(0)
            _org_c = _org_c.unsqueeze(0)
            if len(DISTANCE[_org_c[0, 0]]) >= i:
                continue
            _real_x = to_var(_real_x, volatile=True)
            target_c = to_var(1 - _org_c, volatile=True)
            # Sample pairs of styles and measure the perceptual distance
            # between the two translations of the same input.
            for _ in range(pair_styles):
                style0 = to_var(
                    self.G.random_style(_real_x.size(0)), volatile=True)
                style1 = to_var(
                    self.G.random_style(_real_x.size(0)), volatile=True)
                fake_x0 = self.G(_real_x, target_c, stochastic=style0)
                fake_x1 = self.G(_real_x, target_c, stochastic=style1)
                distance, model = compute_lpips(
                    fake_x0, fake_x1, model=model)
                DISTANCE[_org_c[0, 0]].append(distance)
        if i == len(DISTANCE[0]) == len(DISTANCE[1]):
            break
    print("LPIPS a-b: {}".format(np.array(DISTANCE[0]).mean()))
    print("LPIPS b-a: {}".format(np.array(DISTANCE[1]).mean()))
def debug(self):
    feed = to_var(
        torch.ones(1, self.color_dim, self.image_size, self.image_size),
        volatile=True,
        no_cuda=True)
    label = to_var(torch.ones(1, self.c_dim), volatile=True, no_cuda=True)
    style = to_var(self.random_style(feed), volatile=True, no_cuda=True)
    self.apply_style(feed, label, style)
    self.generator.debug()
def save_multidomain_output(self, real_x, label, save_path, **kwargs):
    self.G.eval()
    self.D.eval()
    # For torch < 1.0 a throwaway file object serves as a dummy context
    # manager in place of torch.no_grad().
    no_grad = open('/var/tmp/null.txt', 'w') \
        if get_torch_version() < 1.0 else torch.no_grad()
    with no_grad:
        real_x = to_var(real_x, volatile=True)
        n_style = self.config.style_debug
        n_interp = self.config.n_interpolation + 10
        _name = 'domain_interpolation'
        no_label = True
        for idx in range(n_style):
            dirname = save_path.replace('.jpg', '')
            filename = '{}_style{}.jpg'.format(_name,
                                               str(idx + 1).zfill(2))
            _save_path = os.path.join(dirname, filename)
            create_dir(_save_path)
            fake_image_list, fake_attn_list = self.Create_Visual_List(
                real_x)
            style = self.G.random_style(1).repeat(real_x.size(0), 1)
            style = to_var(style, volatile=True)
            label0 = to_var(label, volatile=True)
            opposite_label = self.target_multiAttr(
                1 - label, 2)  # 2: black hair
            opposite_label[:, 7] = 0  # Pale skin
            label1 = to_var(opposite_label, volatile=True)
            labels = [label0, label1]
            styles = [style, style]
            domain_interp = self.MMInterpolation(
                labels, styles, n_interp=n_interp)
            for target_de in domain_interp[5:]:
                target_de = to_var(target_de, volatile=True)
                fake_x = self.G(real_x, target_de, style, DE=target_de)
                fake_image_list.append(to_data(fake_x[0], cpu=True))
                fake_attn_list.append(
                    to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
            self._SAVE_IMAGE(
                _save_path,
                fake_image_list,
                no_label=no_label,
                arrow=False,
                circle=False)
            self._SAVE_IMAGE(
                _save_path,
                fake_attn_list,
                Attention=True,
                arrow=False,
                no_label=no_label,
                circle=False)
    self.G.train()
    self.D.train()
def debug(self):
    PRINT(self.config.log, '-- Generator:')
    feed = to_var(
        torch.ones(1, self.color_dim, self.image_size, self.image_size),
        volatile=True,
        no_cuda=True)
    features = self.print_debug(feed, self.main)
    self.print_debug(features, self.fake)
    self.print_debug(features, self.attn)
def debug(self):
    feed = to_var(
        torch.ones(1, self.color_dim, self.image_size, self.image_size),
        volatile=True,
        no_cuda=True)
    PRINT(self.config.log, '-- StyleEncoder:')
    features = self.print_debug(feed, self.main)
    fc_in = features.view(features.size(0), -1)
    self.print_debug(fc_in, self.fc)
def _CLS(self, data):
    data = to_var(data, volatile=True)
    out_label = self.D(data)[1]
    if len(out_label) > 1:
        # Average the sigmoid outputs across the multi-scale
        # discriminator heads.
        out_label = torch.cat(
            [F.sigmoid(out.unsqueeze(-1)) for out in out_label],
            dim=-1).mean(dim=-1)
    else:
        out_label = F.sigmoid(out_label[0])
    out_label = (out_label > 0.5).float()
    return out_label
def debug(self):
    feed = to_var(
        torch.ones(1, self.color_dim, self.image_size, self.image_size),
        volatile=True,
        no_cuda=True)
    modelList = zip(self.cnns_main, self.cnns_src, self.cnns_aux)
    for idx, outs in enumerate(modelList):
        PRINT(self.config.log, '-- MultiDiscriminator ({}):'.format(idx))
        features = self.print_debug(feed, outs[-3])
        self.print_debug(features, outs[-2])
        self.print_debug(features, outs[-1]).view(feed.size(0), -1)
        feed = self.downsample(feed)
def MMInterpolation(self, targets, styles, n_interp=None):
    assert len(targets) == 2 and len(styles) == 2
    if n_interp is None:
        n_interp = self.config.n_interpolation
    in_de0 = self.label2embedding(targets[0], styles[0])
    in_de1 = self.label2embedding(targets[1], styles[1])
    domain_interp = torch.zeros(
        (n_interp, targets[0].size(0), in_de0.shape[-1]))
    domain_interp = to_var(domain_interp, volatile=True)
    # Interpolate each sample's embedding between the two endpoints.
    for i in range(targets[0].size(0)):
        domain_interp[:, i] = interpolation(in_de0[i], in_de1[i],
                                            n_interp)
    return domain_interp
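# The call to `interpolation` above assumes a helper that walks linearly
# between two embedding vectors. A minimal sketch under that assumption
# (the name and signature come from the call site; the project's own
# helper may differ, e.g. it could interpolate spherically):
import torch

def interpolation(v0, v1, n_steps):
    # (n_steps, 1) mixing weights from 0 to 1, on the same device as v0.
    alphas = torch.linspace(0, 1, n_steps).view(-1, 1).type_as(v0)
    # Broadcast to an (n_steps, dim) tensor of intermediate embeddings.
    return (1 - alphas) * v0.unsqueeze(0) + alphas * v1.unsqueeze(0)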
def fit(self, configs):
    self.base_model.train()
    dataloader, optimizer = configs['dataloader'], configs['optimizer']
    try:
        flag = True
        total_steps = len(dataloader)
    except TypeError:
        # Generators have no len(); their batches arrive already wrapped.
        flag = False
        total_steps = 1
    current_epoch = configs['current_epoch']
    total_epochs = configs['total_epochs']
    teacher_updates = configs.get('policy_step', -1)
    logger = configs['logger']
    all_correct = 0
    all_samples = 0
    loss_average = 0
    for idx, (inputs, labels) in enumerate(dataloader):
        optimizer.zero_grad()
        if flag:
            inputs = to_var(inputs)
            labels = to_var(labels)
        predicts = self.base_model(inputs)
        eval_res = self.evaluator(predicts, labels)
        num_correct = eval_res['num_correct']
        num_samples = eval_res['num_samples']
        loss = eval_res['loss']
        all_correct += num_correct
        all_samples += num_samples
        loss.backward()
        optimizer.step()
        logger.info(
            'Policy Steps: [%d] Train: ----- Iteration [%d], '
            'loss: %5.4f, accuracy: %5.4f(%5.4f)' %
            (teacher_updates, current_epoch + 1, loss.cpu().data[0],
             num_correct / num_samples, all_correct / all_samples))
        loss_average += loss.cpu().data[0]
    return loss_average / total_steps
def Gen_update(self, real_x, real_c, fake_c):
    self.train_model(generator=True)
    real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)
    criterion_l1 = torch.nn.L1Loss()
    style_fake = to_var(self.random_style(real_x, seed=self.count_seed))
    style_rec = to_var(
        self.random_style(real_x, seed=self.count_seed + 1))
    style_identity = to_var(
        self.random_style(real_x, seed=self.count_seed + 2))
    self.count_seed += 3

    fake_x = self.G(real_x, fake_c, style_fake)
    g_loss_src, g_loss_cls = self._GAN_LOSS(fake_x[0], real_x, fake_c)
    self.loss['Gsrc'] = g_loss_src
    self.loss['Gcls'] = g_loss_cls * self.config.lambda_cls

    # REC LOSS
    rec_x = self.G(fake_x[0], real_c, style_rec)
    g_loss_rec = criterion_l1(rec_x[0], real_x)
    self.loss['Grec'] = self.config.lambda_rec * g_loss_rec

    # ========== Attention Part ==========#
    self.loss['Gatm'] = self.config.lambda_mask * (
        torch.mean(rec_x[1]) + torch.mean(fake_x[1]))
    self.loss['Gats'] = self.config.lambda_mask_smooth * (
        _compute_loss_smooth(rec_x[1]) + _compute_loss_smooth(fake_x[1]))

    # ========== Identity Part ==========#
    if self.config.Identity:
        idt_x = self.G(real_x, real_c, style_identity)[0]
        g_loss_idt = criterion_l1(idt_x, real_x)
        self.loss['Gidt'] = self.config.lambda_idt * g_loss_idt

    g_loss = self.current_losses('G', **self.loss)
    self.reset_grad()
    g_loss.backward()
    self.g_optimizer.step()
def Dis_update(self, real_x, real_c, fake_c):
    self.train_model(discriminator=True)
    real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)
    style_fake = to_var(self.random_style(real_x, seed=self.count_seed))
    self.count_seed += 1
    fake_x = self.G(real_x, fake_c, style_fake)[0]
    d_loss_src, d_loss_cls = self._GAN_LOSS(real_x, fake_x, real_c)
    self.loss['Dsrc'] = d_loss_src
    self.loss['Dcls'] = d_loss_cls * self.config.lambda_cls
    d_loss = self.current_losses('D', **self.loss)
    self.reset_grad()
    d_loss.backward()
    self.d_optimizer.step()
def val(dataloader, model, val_info, criterion):
    model.eval()
    num_epochs = val_info['num_epochs']
    epoch = val_info['epoch']
    total_steps = len(dataloader)
    total_loss = 0
    for idx, (x_train, x_predict) in enumerate(dataloader):
        x_train = to_var(x_train)
        x_predict = to_var(x_predict)
        data = pack([x_train, x_predict, None],
                    ['x_train', 'x_predict', 'states'])
        configs = pack([False, 10], ['use_gt', 'max_steps'])
        reconstruct, predict = model(data, configs)
        r_loss = criterion(reconstruct, x_train)
        p_loss = criterion(predict, x_predict)
        loss = r_loss + p_loss
        logger.info(
            '[Val] Epoch [%d/%d], Step [%d/%d], Reconstruct Loss: %5.4f, '
            'Predict Loss: %5.4f, Total: %5.4f' %
            (epoch, num_epochs, idx + 1, total_steps, r_loss.data[0],
             p_loss.data[0], loss.data[0]))
        total_loss += loss.data[0]
    return total_loss / total_steps
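# `pack`/`unpack` above bundle positional values with string keys. Judging
# from the call sites here and in the ConvLSTM `forward` below, they behave
# like a dict builder and a keyed getter; a minimal sketch consistent with
# that usage (the project's own helpers may add validation):
def pack(values, keys):
    # pack([a, b], ['x', 'y']) -> {'x': a, 'y': b}
    return dict(zip(keys, values))

def unpack(data, keys):
    # unpack(d, ['x', 'y']) -> (d['x'], d['y'])
    return tuple(data[k] for k in keys)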
def val(self, configs):
    self.base_model.eval()
    dataloader = configs['dataloader']
    total_steps = len(dataloader)
    all_correct = 0
    all_samples = 0
    loss_average = 0
    for idx, (inputs, labels) in enumerate(dataloader):
        inputs = to_var(inputs, volatile=True)
        labels = to_var(labels)
        predicts = self.base_model(inputs)
        eval_res = self.evaluator(predicts, labels)
        num_correct = eval_res['num_correct']
        num_samples = eval_res['num_samples']
        all_correct += num_correct
        all_samples += num_samples
        loss_average += eval_res['loss'].cpu().data[0]
    return all_correct / all_samples, loss_average / total_steps
def train(dataloader, model, optimizer, criterion, train_info):
    model.train()
    num_epochs = train_info['num_epochs']
    epoch = train_info['epoch']
    clip = train_info['clip']
    total_steps = len(dataloader)
    for idx, (x_train, x_predict) in enumerate(dataloader):
        optimizer.zero_grad()
        x_train = to_var(x_train)
        x_predict = to_var(x_predict)
        data = pack([x_train, x_predict, None],
                    ['x_train', 'x_predict', 'states'])
        configs = pack([False, 10], ['use_gt', 'max_steps'])
        reconstruct, predict = model(data, configs)
        r_loss = criterion(reconstruct, x_train)
        p_loss = criterion(predict, x_predict)
        loss = r_loss + p_loss
        loss.backward()
        torch.nn.utils.clip_grad_norm(model.parameters(), clip)
        optimizer.step()
        logger.info(
            'Epoch [%d/%d], Step [%d/%d], Reconstruct Loss: %5.4f, '
            'Predict Loss: %5.4f, Total: %5.4f' %
            (epoch, num_epochs, idx + 1, total_steps, r_loss.data[0],
             p_loss.data[0], loss.data[0]))
def forward(self, in0, in1):
    assert in0.size()[0] == 1  # currently only supports batchSize 1
    if self.colorspace == 'RGB':
        value = util.dssim(
            1. * util.tensor2im(in0.data),
            1. * util.tensor2im(in1.data),
            range=255.).astype('float')
    elif self.colorspace == 'Lab':
        value = util.dssim(
            util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
            util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
            range=100.).astype('float')
    ret_var = to_var(torch.Tensor((value, )))
    if self.use_gpu:
        ret_var = ret_var.cuda()
    return ret_var
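# `util.dssim` above computes structural dissimilarity. A common definition
# is DSSIM = (1 - SSIM) / 2, so the following is a sketch of what the helper
# likely does (assuming scikit-image >= 0.19 for `structural_similarity`;
# the project's `util` module may wrap an older SSIM API):
from skimage.metrics import structural_similarity

def dssim(p0, p1, range=255.):
    # p0, p1: HxWxC numpy images; `range` mirrors the call-site keyword.
    # Returns a scalar in [0, 1]; 0 means identical images.
    return (1 - structural_similarity(
        p0, p1, data_range=range, channel_axis=-1)) / 2.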
def INCEPTION_REAL(self):
    from misc.utils import load_inception
    from scipy.stats import entropy
    net = load_inception()
    net = to_cuda(net)
    net.eval()
    inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
    mode = 'Real'
    data_loader = self.data_loader
    file_name = 'scores/Inception_{}.txt'.format(mode)

    PRED_IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
    IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}

    for i, (real_x, org_c, files) in tqdm(
            enumerate(data_loader),
            desc='Calculating CIS/IS - {}'.format(file_name),
            total=len(data_loader)):
        label = torch.max(org_c, 1)[1][0]
        real_x = to_var((real_x + 1) / 2., volatile=True)
        pred = to_data(
            F.softmax(net(inception_up(real_x)), dim=1),
            cpu=True).numpy()
        PRED_IS[int(label)].append(pred)

    for label in range(len(data_loader.dataset.labels[0])):
        PRED_IS[label] = np.concatenate(PRED_IS[label], 0)
        # prior is computed from all outputs
        py = np.sum(PRED_IS[label], axis=0)
        for j in range(PRED_IS[label].shape[0]):
            pyx = PRED_IS[label][j, :]
            IS[label].append(entropy(pyx, py))

    total_is = []
    file_ = open(file_name, 'w')
    for label in range(len(data_loader.dataset.labels[0])):
        _is = np.exp(np.mean(IS[label]))
        total_is.append(_is)
        PRINT(file_, "Label {}".format(label))
        PRINT(file_, "Inception Score: {:.4f}".format(_is))
        PRINT(file_, "")
    PRINT(
        file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
            np.array(total_is).mean(),
            np.array(total_is).std()))
    file_.close()
def forward(self, in0, in1):
    assert in0.size()[0] == 1  # currently only supports batchSize 1
    if self.colorspace == 'RGB':
        (N, C, X, Y) = in0.size()
        value = torch.mean(
            torch.mean(
                torch.mean((in0 - in1)**2, dim=1).view(N, 1, X, Y),
                dim=2).view(N, 1, 1, Y),
            dim=3).view(N)
        return value
    elif self.colorspace == 'Lab':
        value = util.l2(
            util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)),
            util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)),
            range=100.).astype('float')
        ret_var = to_var(torch.Tensor((value, )))
        if self.use_gpu:
            ret_var = ret_var.cuda()
        return ret_var
def forward(self, data, configs=None):
    _KEYS = ['x', 'states']
    x, states = unpack(data, _KEYS)
    batch_size = states[0][0].size(0)
    if x is None:
        x_c, x_h, x_w = (self.cell_config['in_c'],
                         self.cell_config['in_h'],
                         self.cell_config['in_w'])
        x = to_var(torch.zeros(batch_size, 1, x_c, x_h, x_w))
    # x: batch_size, time_steps, channels, height, width
    time_steps = x.size(1)
    next_states = []
    cell_list = self.cell_list
    current_input = [x[:, t] for t in range(time_steps)]
    for l in range(self.num_layers):
        h, c = states[l]
        for t in range(time_steps):
            data = pack([current_input[t], (h, c)], ['x', 'states'])
            # Feed the previous step's hidden state back into the cell.
            h, c = cell_list[l](data)
            current_input[t] = h
        next_states.append((h, c))
        states[l] = (h, c)
    return states
def Modality(self, target, style, Multimodality, idx=0):
    _size = target.size(0)
    if self.config.dataset_fake in self.MultiLabel_Datasets:
        target = (self.org_label - target)**2  # Swap labels
        target = self.target_multiAttr(target, idx)
    target = to_var(target, volatile=True)
    if Multimodality == 1:
        # Random Styles
        domain_embedding = self.label2embedding(target, style,
                                                _torch=True)
    elif Multimodality == 2:
        # Style interpolation | Fixed Labels
        # The batch belongs to the same image
        style0 = style[0].repeat(_size, 1)
        style1 = style[1].repeat(_size, 1)
        targets = [target, target]
        styles = [style0, style1]
        domain_embedding = self.MMInterpolation(targets, styles)[:, 0]
    elif Multimodality == 3:
        # Style constant | Progressive swap label
        n_interp = self.config.n_interpolation + 5
        target0 = self.org_label
        target1 = target
        style = style[0].repeat(_size, 1)
        targets = [target0, target1]
        styles = [style, style]
        domain_embedding = self.MMInterpolation(targets, styles,
                                                n_interp)[5:, 0]
    else:
        # Unimodal
        style = style[0].repeat(_size, 1)
        domain_embedding = self.label2embedding(target, style,
                                                _torch=True)
    return domain_embedding
def to_var(self, *args):
    # Wrap every argument with to_var and return them in order.
    return [to_var(arg) for arg in args]
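# Usage sketch for the variadic wrapper above, as it is called from
# Gen_update/Dis_update:
#
#     real_x, real_c, fake_c = self.to_var(real_x, real_c, fake_c)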
def init_hidden(self, batch_size, cuda=False):
    # Zero-initialized (h, c) pair for one ConvLSTM layer.
    return (to_var(
        torch.zeros(batch_size, self.h_c, self.in_h, self.in_w)),
            to_var(
        torch.zeros(batch_size, self.h_c, self.in_h, self.in_w)))
def forward(self, in0, in1, retNumpy=True):
    '''
    Computes the distance between image patches in0 and in1.
    INPUTS
        in0, in1 - torch.Tensor object of shape Nx3xXxY - image patch
                   scaled to [-1, 1]
        retNumpy - [False] to return as torch.Tensor,
                   [True] to return as numpy array
    OUTPUT
        computed distances between in0 and in1
    '''
    self.input_ref = in0
    self.input_p0 = in1

    if self.use_gpu:
        self.input_ref = self.input_ref.cuda()
        self.input_p0 = self.input_p0.cuda()

    self.var_ref = to_var(self.input_ref, requires_grad=True)
    self.var_p0 = to_var(self.input_p0, requires_grad=True)

    self.d0 = self.forward_pair(self.var_ref, self.var_p0)
    self.loss_total = self.d0

    def convert_output(d0):
        if retNumpy:
            ans = d0.cpu().data.numpy()
            if not self.spatial:
                ans = ans.flatten()
            else:
                assert ans.shape[0] == 1 and len(ans.shape) == 4
                # Reshape to usual numpy image format:
                # (height, width, channels)
                return ans[0, ...].transpose([1, 2, 0])
            return ans
        else:
            return d0

    if self.spatial:
        L = [convert_output(x) for x in self.d0]
        spatial_shape = self.spatial_shape
        if spatial_shape is None:
            if self.spatial_factor is None:
                spatial_shape = (in0.size()[2], in0.size()[3])
            else:
                spatial_shape = (
                    max([x.shape[0] for x in L]) * self.spatial_factor,
                    max([x.shape[1] for x in L]) * self.spatial_factor)
        L = [
            skimage.transform.resize(
                x, spatial_shape, order=self.spatial_order, mode='edge')
            for x in L
        ]
        L = np.mean(np.concatenate(L, 2) * len(L), 2)
        return L
    else:
        return convert_output(self.d0)
def train_inception(batch_size, shuffling=False, num_workers=4, **kwargs):
    from torchvision.models import inception_v3
    from misc.utils import to_var, to_cuda, to_data
    from torchvision import transforms
    from torch.utils.data import DataLoader
    from PIL import Image  # needed for Image.ANTIALIAS below
    import torch.nn.functional as F
    import torch
    import torch.nn as nn
    import tqdm

    metadata_path = os.path.join('data', 'RafD', 'normal')
    # inception Norm
    image_size = 299

    transform = []
    window = int(image_size / 10)
    transform += [
        transforms.Resize((image_size + window, image_size + window),
                          interpolation=Image.ANTIALIAS)
    ]
    transform += [
        transforms.RandomResizedCrop(
            image_size, scale=(0.7, 1.0), ratio=(0.8, 1.2))
    ]
    transform += [transforms.RandomHorizontalFlip()]
    transform += [transforms.ToTensor()]
    transform = transforms.Compose(transform)

    dataset_train = RafD(
        image_size, metadata_path, transform, 'train',
        shuffling=True, **kwargs)
    dataset_test = RafD(
        image_size, metadata_path, transform, 'test',
        shuffling=False, **kwargs)

    train_loader = DataLoader(
        dataset=dataset_train, batch_size=batch_size,
        shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(
        dataset=dataset_test, batch_size=batch_size,
        shuffle=False, num_workers=num_workers)

    num_labels = len(train_loader.dataset.labels[0])
    n_epochs = 100
    net = inception_v3(pretrained=True, transform_input=True)
    net.aux_logits = False
    num_ftrs = net.fc.in_features
    net.fc = nn.Linear(num_ftrs, num_labels)

    net_save = metadata_path + '/inception_v3/{}.pth'
    if not os.path.isdir(os.path.dirname(net_save)):
        os.makedirs(os.path.dirname(net_save))
    print("Model will be saved at: " + net_save)

    optimizer = torch.optim.RMSprop(net.parameters(), lr=1e-5)
    to_cuda(net)

    for epoch in range(n_epochs):
        LOSS = {'train': [], 'test': []}
        OUTPUT = {'train': [], 'test': []}
        LABEL = {'train': [], 'test': []}

        net.eval()
        for i, (data, label, files) in tqdm.tqdm(
                enumerate(test_loader),
                total=len(test_loader),
                desc='Validating Inception V3 | RafD'):
            data = to_var(data, volatile=True)
            label = to_var(torch.max(label, dim=1)[1], volatile=True)
            out = net(data)
            loss = F.cross_entropy(out, label)
            LOSS['test'].append(to_data(loss, cpu=True)[0])
            OUTPUT['test'].extend(
                to_data(F.softmax(out, dim=1).max(1)[1],
                        cpu=True).tolist())
            LABEL['test'].extend(to_data(label, cpu=True).tolist())
        acc_test = (np.array(OUTPUT['test']) == np.array(
            LABEL['test'])).mean()
        print('[Test] Loss: {:.4f} Acc: {:.4f}'.format(
            np.array(LOSS['test']).mean(), acc_test))

        net.train()
        for i, (data, label, files) in tqdm.tqdm(
                enumerate(train_loader),
                total=len(train_loader),
                desc='[{}/{}] Train Inception V3 | RafD'.format(
                    str(epoch).zfill(5), str(n_epochs).zfill(5))):
            data = to_var(data)
            label = to_var(torch.max(label, dim=1)[1])
            out = net(data)
            loss = F.cross_entropy(out, label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            LOSS['train'].append(to_data(loss, cpu=True)[0])
            OUTPUT['train'].extend(
                to_data(F.softmax(out, dim=1).max(1)[1],
                        cpu=True).tolist())
            LABEL['train'].extend(to_data(label, cpu=True).tolist())
        acc_train = (np.array(OUTPUT['train']) == np.array(
            LABEL['train'])).mean()
        print('[Train] Loss: {:.4f} Acc: {:.4f}'.format(
            np.array(LOSS['train']).mean(), acc_train))

        torch.save(net.state_dict(), net_save.format(str(epoch).zfill(5)))
        train_loader.dataset.shuffle(epoch)
def INCEPTION(self):
    from misc.utils import load_inception
    from scipy.stats import entropy
    n_styles = 20
    net = load_inception()
    net = to_cuda(net)
    net.eval()
    self.G.eval()
    inception_up = nn.Upsample(size=(299, 299), mode='bilinear')
    mode = 'SMIT'
    data_loader = self.data_loader
    file_name = 'scores/Inception_{}.txt'.format(mode)

    PRED_IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
    CIS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}
    IS = {i: [] for i in range(len(data_loader.dataset.labels[0]))}

    for i, (real_x, org_c, files) in tqdm(
            enumerate(data_loader),
            desc='Calculating CIS/IS - {}'.format(file_name),
            total=len(data_loader)):
        PRED_CIS = {
            i: []
            for i in range(len(data_loader.dataset.labels[0]))
        }
        org_label = torch.max(org_c, 1)[1][0]
        real_x = real_x.repeat(n_styles, 1, 1, 1)
        real_x = to_var(real_x, volatile=True)
        target_c = (org_c * 0).repeat(n_styles, 1)
        target_c = to_var(target_c, volatile=True)
        for label in range(len(data_loader.dataset.labels[0])):
            if org_label == label:
                continue
            target_c *= 0
            target_c[:, label] = 1
            style = to_var(
                self.G.random_style(n_styles),
                volatile=True) if mode == 'SMIT' else None

            fake = (self.G(real_x, target_c, style)[0] + 1) / 2

            pred = to_data(
                F.softmax(net(inception_up(fake)), dim=1),
                cpu=True).numpy()
            PRED_CIS[label].append(pred)
            PRED_IS[label].append(pred)

            # CIS for each image
            PRED_CIS[label] = np.concatenate(PRED_CIS[label], 0)
            # prior is computed from outputs given a specific input
            py = np.sum(PRED_CIS[label], axis=0)
            for j in range(PRED_CIS[label].shape[0]):
                pyx = PRED_CIS[label][j, :]
                CIS[label].append(entropy(pyx, py))

    for label in range(len(data_loader.dataset.labels[0])):
        PRED_IS[label] = np.concatenate(PRED_IS[label], 0)
        # prior is computed from all outputs
        py = np.sum(PRED_IS[label], axis=0)
        for j in range(PRED_IS[label].shape[0]):
            pyx = PRED_IS[label][j, :]
            IS[label].append(entropy(pyx, py))

    total_cis = []
    total_is = []
    file_ = open(file_name, 'w')
    for label in range(len(data_loader.dataset.labels[0])):
        cis = np.exp(np.mean(CIS[label]))
        total_cis.append(cis)
        _is = np.exp(np.mean(IS[label]))
        total_is.append(_is)
        PRINT(file_, "Label {}".format(label))
        PRINT(file_, "Inception Score: {:.4f}".format(_is))
        PRINT(file_, "conditional Inception Score: {:.4f}".format(cis))
        PRINT(file_, "")
    PRINT(
        file_, "[TOTAL] Inception Score: {:.4f} +/- {:.4f}".format(
            np.array(total_is).mean(),
            np.array(total_is).std()))
    PRINT(
        file_,
        "[TOTAL] conditional Inception Score: {:.4f} +/- {:.4f}".format(
            np.array(total_cis).mean(),
            np.array(total_cis).std()))
    file_.close()
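# The IS/CIS bookkeeping above implements the Inception Score,
# IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ). scipy's entropy(pyx, py)
# normalizes both arguments to distributions and returns their KL
# divergence, which is why the unnormalized sum `py` can be passed
# directly as the prior. A minimal self-contained sketch of the same
# computation:
import numpy as np
from scipy.stats import entropy

def inception_score(preds):
    # preds: (N, num_classes) softmax outputs for N generated images.
    py = np.mean(preds, axis=0)                # marginal p(y)
    kls = [entropy(pyx, py) for pyx in preds]  # KL(p(y|x) || p(y))
    return np.exp(np.mean(kls))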
def __init__(self,
             pnet_type='vgg',
             pnet_rand=False,
             pnet_tune=False,
             use_dropout=True,
             use_gpu=True,
             spatial=False,
             version='0.1'):
    super(PNetLin, self).__init__()
    self.use_gpu = use_gpu
    self.pnet_type = pnet_type
    self.pnet_tune = pnet_tune
    self.pnet_rand = pnet_rand
    self.spatial = spatial
    self.version = version

    if self.pnet_type in ['vgg', 'vgg16']:
        net_type = pn.vgg16
        self.chns = [64, 128, 256, 512, 512]
    elif self.pnet_type == 'alex':
        net_type = pn.alexnet
        self.chns = [64, 192, 384, 256, 256]
    elif self.pnet_type == 'squeeze':
        net_type = pn.squeezenet
        self.chns = [64, 128, 256, 384, 384, 512, 512]

    if self.pnet_tune:
        self.net = net_type(pretrained=not self.pnet_rand,
                            requires_grad=True)
    else:
        # Keep the frozen backbone in a plain list so its parameters
        # are not registered with this module.
        self.net = [
            net_type(pretrained=not self.pnet_rand, requires_grad=False),
        ]

    self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
    self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
    self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
    self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
    self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
    self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
    if self.pnet_type == 'squeeze':  # 7 layers for squeezenet
        self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
        self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
        self.lins += [self.lin5, self.lin6]

    self.shift = to_var(
        torch.Tensor([-.030, -.088, -.188]).view(1, 3, 1, 1))
    self.scale = to_var(torch.Tensor([.458, .448, .450]).view(1, 3, 1, 1))

    if use_gpu:
        if self.pnet_tune:
            self.net.cuda()
        else:
            self.net[0].cuda()
        self.shift = self.shift.cuda()
        self.scale = self.scale.cuda()
        self.lin0.cuda()
        self.lin1.cuda()
        self.lin2.cuda()
        self.lin3.cuda()
        self.lin4.cuda()
        if self.pnet_type == 'squeeze':
            self.lin5.cuda()
            self.lin6.cuda()
def save_multimodal_output(self,
                           real_x,
                           label,
                           save_path,
                           interpolation=False,
                           **kwargs):
    self.G.eval()
    self.D.eval()
    n_rep = 4
    no_label = self.config.dataset_fake in self.Binary_Datasets
    # Dummy context manager for torch < 1.0 (no torch.no_grad()).
    no_grad = open('/var/tmp/null.txt', 'w') \
        if get_torch_version() < 1.0 else torch.no_grad()
    with no_grad:
        real_x = to_var(real_x, volatile=True)
        out_label = to_var(label, volatile=True)
        for idx, (real_x0, real_c0) in enumerate(zip(real_x, out_label)):
            _name = 'multimodal'
            if interpolation == 1:
                _name += '_interp'
            elif interpolation == 2:
                _name = 'multidomain_interp'
            _save_path = os.path.join(
                save_path.replace('.jpg', ''),
                '{}_{}.jpg'.format(_name, str(idx).zfill(4)))
            create_dir(_save_path)
            real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
            real_c0 = real_c0.repeat(n_rep, 1)
            fake_image_list, fake_attn_list = self.Create_Visual_List(
                real_x0, Multimodal=True)
            target_c_list = [real_c0] * 7
            for _, target_c in enumerate(target_c_list):
                if interpolation == 0:
                    # Independent random styles for every repetition.
                    style_ = to_var(self.G.random_style(n_rep),
                                    volatile=True)
                    embeddings = self.label2embedding(
                        target_c, style_, _torch=True)
                elif interpolation == 1:
                    # Interpolate between two styles, fixed label.
                    style_ = to_var(self.G.random_style(1), volatile=True)
                    style1 = to_var(self.G.random_style(1), volatile=True)
                    _target_c = target_c[0].unsqueeze(0)
                    styles = [style_, style1]
                    targets = [_target_c, _target_c]
                    embeddings = self.MMInterpolation(
                        targets, styles, n_interp=n_rep)[:, 0]
                elif interpolation == 2:
                    # Interpolate between two labels, fixed style.
                    style_ = to_var(self.G.random_style(1), volatile=True)
                    target0 = 1 - target_c[0].unsqueeze(0)
                    target1 = target_c[0].unsqueeze(0)
                    styles = [style_, style_]
                    targets = [target0, target1]
                    embeddings = self.MMInterpolation(
                        targets, styles, n_interp=n_rep)[:, 0]
                else:
                    raise ValueError(
                        "There are only 2 types of interpolation: "
                        "Multimodal and Multi-domain")
                fake_x = self.G(real_x0, target_c, style_, DE=embeddings)
                fake_image_list.append(to_data(fake_x[0], cpu=True))
                fake_attn_list.append(
                    to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
            self._SAVE_IMAGE(
                _save_path,
                fake_image_list,
                mode='style_' + chr(65 + idx),
                no_label=no_label,
                arrow=interpolation,
                circle=False)
            self._SAVE_IMAGE(
                _save_path,
                fake_attn_list,
                Attention=True,
                mode='style_' + chr(65 + idx),
                arrow=interpolation,
                no_label=no_label,
                circle=False)
    self.G.train()
    self.D.train()
def save_multimodal_output(self,
                           real_x,
                           label,
                           save_path,
                           interpolation=False,
                           **kwargs):
    self.G.eval()
    self.D.eval()
    n_rep = 4
    no_label = self.config.dataset_fake in self.Binary_Datasets
    # Dummy context manager for torch < 1.0 (no torch.no_grad()).
    no_grad = open('/var/tmp/null.txt', 'w') \
        if get_torch_version() < 1.0 else torch.no_grad()
    with no_grad:
        real_x = to_var(real_x, volatile=True)
        out_label = to_var(label, volatile=True)
        for idx, (real_x0, real_c0) in enumerate(zip(real_x, out_label)):
            _name = 'multimodal'
            if interpolation:
                _name = _name + '_interp'
            _save_path = os.path.join(
                save_path.replace('.jpg', ''),
                '{}_{}.jpg'.format(_name, str(idx).zfill(4)))
            create_dir(_save_path)
            real_x0 = real_x0.repeat(n_rep, 1, 1, 1)
            # Repeat the label vector along the batch dimension.
            real_c0 = real_c0.repeat(n_rep, 1)
            fake_image_list = [
                to_data(
                    color_frame(
                        single_source(real_x0),
                        thick=5,
                        color='green',
                        first=True),
                    cpu=True)
            ]
            fake_attn_list = [
                to_data(
                    color_frame(
                        single_source(real_x0),
                        thick=5,
                        color='green',
                        first=True),
                    cpu=True)
            ]
            target_c_list = [real_c0] * 7
            for _, target_c in enumerate(target_c_list):
                if not interpolation:
                    style_ = self.G.random_style(n_rep)
                else:
                    # Spherically interpolate between two sampled styles.
                    z0 = to_data(
                        self.G.random_style(1), cpu=True).numpy()[0]
                    z1 = to_data(
                        self.G.random_style(1), cpu=True).numpy()[0]
                    style_ = self.G.random_style(n_rep)
                    style_[:] = torch.FloatTensor(
                        np.array([
                            slerp(sz, z0, z1)
                            for sz in np.linspace(0, 1, n_rep)
                        ]))
                style = to_var(style_, volatile=True)
                fake_x = self.G(real_x0, target_c, stochastic=style)
                fake_image_list.append(to_data(fake_x[0], cpu=True))
                fake_attn_list.append(
                    to_data(fake_x[1].repeat(1, 3, 1, 1), cpu=True))
            self._SAVE_IMAGE(
                _save_path,
                fake_image_list,
                mode='style_' + chr(65 + idx),
                no_label=no_label,
                arrow=interpolation,
                circle=False)
            self._SAVE_IMAGE(
                _save_path,
                fake_attn_list,
                Attention=True,
                mode='style_' + chr(65 + idx),
                arrow=interpolation,
                no_label=no_label,
                circle=False)
    self.G.train()
    self.D.train()
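# `slerp` above is spherical linear interpolation between two latent
# codes, commonly preferred over linear interpolation when walking a
# path through a Gaussian latent space. A standard numpy sketch (the
# project's own `slerp` may differ in edge-case handling):
import numpy as np

def slerp(val, low, high):
    # val in [0, 1]; low, high: 1-D latent vectors.
    omega = np.arccos(np.clip(
        np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)),
        -1, 1))
    so = np.sin(omega)
    if so == 0:
        # Degenerate (parallel) case: fall back to linear interpolation.
        return (1.0 - val) * low + val * high
    return (np.sin((1.0 - val) * omega) / so * low
            + np.sin(val * omega) / so * high)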
def val_teacher(self, configs):
    # TODO: test for the policy. Plot the curve of
    # effective_samples vs. test accuracy.
    '''
    :param configs:
        Required:
            state_func
            dataloader: student/dev/test
            optimizer: student
            lr_scheduler: student
            logger
        Optional:
            threshold
            M
            num_classes
            max_t (Note: should be consistent with training)
    :return:
    '''
    teacher = self.teacher_net
    # ==================== train student from scratch ============
    init_params(self.student_net)
    student = self.student_net
    # ==================== fetch configs [optional] ===============
    threshold = configs.get('threshold', 0.5)
    M = configs.get('M', 128)
    num_classes = configs.get('num_classes', 10)
    max_t = configs.get('max_t', 50000)
    # =================== fetch configs [required] ================
    state_func = configs['state_func']
    student_dataloader = configs['dataloader']['student']
    dev_dataloader = configs['dataloader']['dev']
    test_dataloader = configs['dataloader']['test']
    student_optimizer = configs['optimizer']['student']
    student_lr_scheduler = configs['lr_scheduler']['student']
    logger = configs['logger']
    # ================== init tracking history ====================
    training_loss_history = []
    val_loss_history = []
    student_updates = 0
    best_acc_on_dev = 0
    best_acc_on_test = 0
    i_tau = 0
    effective_num = 0
    effnum_acc_curves = []
    while i_tau < max_t:
        i_tau += 1
        count = 0
        input_pool = []
        label_pool = []
        # ================== collect training batch ============
        for idx, (inputs, labels) in enumerate(student_dataloader):
            inputs = to_var(inputs)
            labels = to_var(labels)
            state_configs = {
                'num_classes': num_classes,
                'labels': labels,
                'inputs': inputs,
                'student': student,
                'current_iter': i_tau,
                'max_iter': max_t,
                'train_loss_history': training_loss_history,
                'val_loss_history': val_loss_history
            }
            # Compute the state vector fed to the teacher policy.
            states = state_func(state_configs)
            _inputs = {'input': states}
            predicts = teacher(_inputs, None)
            # Keep only samples whose selection probability passes
            # the threshold.
            indices = torch.nonzero(predicts.data.squeeze() >= threshold)
            if len(indices) == 0:
                continue
            count += len(indices)
            selected_inputs = inputs[indices.squeeze()].view(
                len(indices), *inputs.size()[1:])
            selected_labels = labels[indices.squeeze()].view(-1, 1)
            input_pool.append(selected_inputs)
            label_pool.append(selected_labels)
            if count >= M:
                effective_num += count
                break
        # ================== prepare training data =============
        inputs = torch.cat(input_pool, 0)
        labels = torch.cat(label_pool, 0)
        st_configs = {
            'dataloader': to_generator([inputs, labels]),
            'optimizer': student_optimizer,
            'current_epoch': student_updates,
            'total_epochs': -1,
            'logger': logger
        }
        # ================= feed the selected batch ============
        train_loss = student.fit(st_configs)
        training_loss_history.append(train_loss)
        student_updates += 1
        student_lr_scheduler(student_optimizer, student_updates)
        # ================ test on dev set =====================
        st_configs['dataloader'] = dev_dataloader
        acc, val_loss = student.val(st_configs)
        best_acc_on_dev = max(best_acc_on_dev, acc)
        logger.info(
            'Test on Dev: Iteration [%d], accuracy: %5.4f, best: %5.4f' %
            (student_updates, acc, best_acc_on_dev))
        val_loss_history.append(val_loss)
        # =============== test on test set ======================
        st_configs['dataloader'] = test_dataloader
        acc, test_loss = student.val(st_configs)
        best_acc_on_test = max(best_acc_on_test, acc)
        logger.info(
            'Testing Set: Iteration [%d], accuracy: %5.4f, best: %5.4f' %
            (student_updates, acc, best_acc_on_test))
        effnum_acc_curves.append((effective_num, acc))
    return effnum_acc_curves
def fit_teacher(self, configs):
    '''
    :param configs:
        Required:
            state_func: [function] used to compute the state vector
            dataloader: [dict]
                teacher: teacher training data loader
                student: student training data loader
                dev: for testing the student model so as to compute
                     the reward for the teacher
                test: student testing data loader
            optimizer: [dict]
                teacher: the optimizer for the teacher
                student: the optimizer for the student
            lr_scheduler: [dict]
                teacher: the learning rate scheduler for the teacher
                student: the learning rate scheduler for the student
            logger: the logger
        Optional:
            max_t: [int] [50,000] the maximum number of iterations
                before stopping the teaching; once reached, return a
                reward of 0.
            tau: [float32] [0.8] the expected accuracy of the student
                model on the dev set.
            threshold: [float32] [0.5] the probability threshold for
                choosing a sample.
            M: [int] [128] the required batch size for training the
                student model.
            max_non_increasing_steps: [int] [10] the maximum number of
                iterations without the reward increasing; if exceeded,
                stop training the teacher model.
            num_classes: [int] [10] the number of classes in the
                training set.
    :return:
    '''
    teacher = self.teacher_net
    student = self.student_net
    # ==================== fetch configs [optional] ===============
    max_t = configs['max_t']
    tau = configs['tau']
    M = configs['M']
    max_non_increasing_steps = configs['max_non_increasing_steps']
    num_classes = configs['num_classes']
    # =================== fetch configs [required] ================
    state_func = configs['state_func']
    teacher_dataloader = configs['dataloader']['teacher']
    dev_dataloader = configs['dataloader']['dev']
    teacher_optimizer = configs['optimizer']['teacher']
    student_optimizer = configs['optimizer']['student']
    teacher_lr_scheduler = configs['lr_scheduler']['teacher']
    student_lr_scheduler = configs['lr_scheduler']['student']
    logger = configs['logger']
    # ================== init tracking history ====================
    rewards = []
    training_loss_history = []
    val_loss_history = []
    num_steps_to_achieve = []
    non_increasing_steps = 0
    student_updates = 0
    teacher_updates = 0
    best_acc_on_dev = 0

    while True:
        i_tau = 0
        actions = []

        def overloaded_init_params(x):
            init_params(x)

        while i_tau < max_t:
            i_tau += 1
            count = 0
            input_pool = []
            label_pool = []
            # ================== collect training batch ============
            while True:
                for idx, (inputs, labels) in enumerate(
                        teacher_dataloader):
                    inputs = to_var(inputs)
                    labels = to_var(labels)
                    state_configs = {
                        'num_classes': num_classes,
                        'labels': labels,
                        'inputs': inputs,
                        'student': student.train(),
                        'current_iter': i_tau,
                        'max_iter': max_t,
                        'train_loss_history': training_loss_history,
                        'val_loss_history': val_loss_history
                    }
                    # Compute the state vector fed to the teacher policy.
                    states = state_func(state_configs)
                    _inputs = {'input': states.detach()}
                    predicts = teacher(_inputs, None)
                    # Sample a keep/drop action per example.
                    sampled_actions = torch.bernoulli(
                        predicts.data.squeeze())
                    indices = torch.nonzero(sampled_actions)
                    if len(indices) == 0:
                        continue
                    count += len(indices)
                    selected_inputs = inputs[indices.squeeze()].view(
                        len(indices), *inputs.size()[1:])
                    selected_labels = labels[indices.squeeze()].view(-1, 1)
                    input_pool.append(selected_inputs)
                    label_pool.append(selected_labels)
                    # log pi(a|s), signed so that kept samples contribute
                    # log(p) and dropped ones -log(p).
                    actions.append(
                        torch.log(predicts.squeeze()) *
                        to_var(sampled_actions - 0.5) * 2)
                    if count >= M:
                        break
                if count >= M:
                    break
            # ================== prepare training data =============
            inputs = torch.cat(input_pool, 0)
            labels = torch.cat(label_pool, 0)
            st_configs = {
                'dataloader': to_generator([inputs, labels]),
                'optimizer': student_optimizer,
                'current_epoch': student_updates,
                'total_epochs': 0,
                'logger': logger,
                'policy_step': teacher_updates
            }
            # ================= feed the selected batch ============
            train_loss = student.fit(st_configs)
            training_loss_history.append(train_loss)
            student_updates += 1
            student_lr_scheduler(student_optimizer, student_updates)
            # ================ test on dev set =====================
            st_configs['dataloader'] = dev_dataloader
            acc, val_loss = student.val(st_configs)
            best_acc_on_dev = max(best_acc_on_dev, acc)
            logger.info(
                'Stage [%d], Policy Steps: [%d] Test on Dev: '
                'Iteration [%d], accuracy: %5.4f, best: %5.4f, '
                'loss: %5.4f' % (0, teacher_updates, student_updates,
                                 acc, best_acc_on_dev, val_loss))
            val_loss_history.append(val_loss)
            # ==== check if the expected accuracy is reached or ====
            # ==== max_t is exceeded ================================
            if acc >= tau or i_tau == max_t:
                num_steps_to_achieve.append(i_tau)
                teacher_optimizer.zero_grad()
                reward = -math.log(i_tau / max_t)
                baseline = 0 if len(rewards) == 0 \
                    else 0.8 * baseline + 0.2 * reward
                last_reward = 0 if len(rewards) == 0 else rewards[-1]
                if last_reward >= reward:
                    non_increasing_steps += 1
                else:
                    non_increasing_steps = 0
                # REINFORCE: scale the summed log-probabilities by the
                # advantage (reward - baseline).
                loss = -sum([torch.sum(_)
                             for _ in actions]) * (reward - baseline)
                print('=' * 80)
                print(actions[0])
                print('=' * 80)
                logger.info(
                    'Policy: Iterations [%d], stops at %d/%d to achieve '
                    '%5.4f, loss: %5.4f, reward: %5.4f(%5.4f)' %
                    (teacher_updates, i_tau, max_t, acc,
                     loss.cpu().data[0], reward, baseline))
                rewards.append(reward)
                loss.backward()
                teacher_optimizer.step()
                for name, param in teacher.named_parameters():
                    print(name, param)
                teacher_updates += 1
                teacher_lr_scheduler(teacher_optimizer, teacher_updates)
                # ========= reinitialize the student network =========
                overloaded_init_params(self.student_net)
                student_updates = 0
                best_acc_on_dev = 0
                print("Initialized the student net's parameters")
                # ========== break for next batch ====================
                break
        # ============= policy converged (stopping criteria) ==========
        if non_increasing_steps >= max_non_increasing_steps:
            torch.save({'num_steps_to_achieve': num_steps_to_achieve},
                       './tmp/curve_stage_%d.pth.tar' % 0)
            print(num_steps_to_achieve)
            return num_steps_to_achieve
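# The teacher update above is a REINFORCE policy-gradient step: each entry
# of `actions` is a signed log pi(a|s) term, and the surrogate loss
# -(reward - baseline) * sum(log-probs) yields the score-function gradient
# estimator. A minimal sketch of the same update in isolation (names are
# illustrative, not the project's API):
def reinforce_step(log_probs, reward, baseline, optimizer):
    # log_probs: list of log pi(a_t|s_t) terms with grad history attached.
    loss = -sum(lp.sum() for lp in log_probs) * (reward - baseline)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss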