import os
import shutil

import torch
import torch.backends.cudnn as cudnn
# from utils import logging
from torch.autograd import Variable
from sklearn import metrics
from visdom import Visdom

# train_args, argument_report, Scalar, and SimpleTimer are project-local
# helpers assumed to be importable from this repository's utility modules.

if __name__ == '__main__':
    FG = train_args()
    vis = Visdom(port=FG.vis_port, env=str(FG.vis_env))
    vis.text(argument_report(FG, end='<br>'), win='config')

    # torch setting
    device = torch.device('cuda:{}'.format(FG.devices[0]))
    torch.cuda.set_device(FG.devices[0])
    timer = SimpleTimer()

    FG.save_dir = str(FG.vis_env)
    if not os.path.exists(FG.save_dir):
        os.makedirs(FG.save_dir)

    printers = dict(
        lr=Scalar(vis, 'lr', opts=dict(
            showlegend=True, title='lr', ytickmin=0, ytickmax=2.0)),
        # this entry was truncated in the source; its opts are completed to
        # match the identical D_loss printer defined later in this section
        D_loss=Scalar(vis, 'D_loss', opts=dict(
            showlegend=True, title='D loss', ytickmin=0, ytickmax=2.0)))
def main():
    # option flags
    FLG = train_args()

    # torch setting
    device = torch.device('cuda:{}'.format(FLG.devices[0]))
    torch.backends.cudnn.benchmark = True
    torch.cuda.set_device(FLG.devices[0])

    # create summary and report the options
    visenv = FLG.model
    summary = Summary(port=39199, env=visenv)
    summary.viz.text(argument_report(FLG, end='<br>'),
                     win='report' + str(FLG.running_fold))
    train_report = ScoreReport()
    valid_report = ScoreReport()
    timer = SimpleTimer()
    fold_str = 'fold' + str(FLG.running_fold)
    best_score = dict(epoch=0, loss=1e+100, accuracy=0)

    ### create dataset ###
    # k-fold split (allow_pickle is required to load a pickled dict via np.load)
    target_dict = np.load(pjoin(FLG.data_root, 'target_dict.pkl'),
                          allow_pickle=True)
    trainblock, validblock, ratio = fold_split(
        FLG.fold, FLG.running_fold, FLG.labels,
        np.load(pjoin(FLG.data_root, 'subject_indices.npy')), target_dict)

    def _dataset(block, transform):
        return ADNIDataset(FLG.labels, pjoin(FLG.data_root, FLG.modal),
                           block, target_dict, transform=transform)

    # create train set
    trainset = _dataset(trainblock, transform_presets(FLG.augmentation))
    # create valid set (nine-crop evaluation when training used random crop)
    validset = _dataset(
        validblock,
        transform_presets('nine crop' if FLG.augmentation == 'random crop'
                          else 'no augmentation'))

    # data loaders
    trainloader = DataLoader(trainset, batch_size=FLG.batch_size,
                             shuffle=True, num_workers=4, pin_memory=True)
    validloader = DataLoader(validset, num_workers=4, pin_memory=True)

    # data check
    # for image, _ in trainloader:
    #     summary.image3d('asdf', image)

    # create model
    def kaiming_init(tensor):
        return kaiming_normal_(tensor, mode='fan_out', nonlinearity='relu')

    if 'plane' in FLG.model:
        model = Plane(len(FLG.labels), name=FLG.model,
                      weights_initializer=kaiming_init)
    elif 'resnet11' in FLG.model:
        model = resnet11(len(FLG.labels), FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet19' in FLG.model:
        model = resnet19(len(FLG.labels), FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet35' in FLG.model:
        model = resnet35(len(FLG.labels), FLG.model,
                         weights_initializer=kaiming_init)
    elif 'resnet51' in FLG.model:
        model = resnet51(len(FLG.labels), FLG.model,
                         weights_initializer=kaiming_init)
    else:
        raise NotImplementedError(FLG.model)
    print_model_parameters(model)
    model = torch.nn.DataParallel(model, FLG.devices)
    model.to(device)

    # criterion: weight classes by the (reversed, doubled) label ratio
    train_criterion = torch.nn.CrossEntropyLoss(weight=torch.Tensor(
        list(map(lambda x: x * 2, reversed(ratio))))).to(device)
    valid_criterion = torch.nn.CrossEntropyLoss().to(device)

    # TODO: resume

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=FLG.lr,
                                 weight_decay=FLG.l2_decay)
    # scheduler
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, FLG.lr_gamma)

    start_epoch = 0
    global_step = start_epoch * len(trainloader)
    pbar = None
    for epoch in range(1, FLG.max_epoch + 1):
        timer.tic()
        # scheduler.step() before the optimizer step follows the pre-1.1
        # PyTorch convention this code was written against
        scheduler.step()
        summary.scalar('lr', fold_str, epoch - 1,
                       optimizer.param_groups[0]['lr'],
                       ytickmin=0, ytickmax=FLG.lr)

        # train
        torch.set_grad_enabled(True)
        model.train(True)
        train_report.clear()
        if pbar is None:
            pbar = tqdm(total=len(trainloader) * FLG.validation_term,
                        desc='Epoch {:<3}-{:>3} train'.format(
                            epoch, epoch + FLG.validation_term - 1))
        for images, targets in trainloader:
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            optimizer.zero_grad()
            outputs = model(images)
            loss = train_criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            train_report.update_true(targets)
            train_report.update_score(F.softmax(outputs, dim=1))

            summary.scalar('loss', 'train ' + fold_str,
                           global_step / len(trainloader), loss.item(),
                           ytickmin=0, ytickmax=1)
            pbar.update()
            global_step += 1

        if epoch % FLG.validation_term != 0:
            timer.toc()
            continue
        pbar.close()

        # valid
        torch.set_grad_enabled(False)
        model.eval()
        valid_report.clear()
        pbar = tqdm(total=len(validloader),
                    desc='Epoch {:>3} valid'.format(epoch))
        for images, targets in validloader:
            true = targets
            npatchs = 1
            # nine-crop batches arrive as (N, npatchs, C, X, Y, Z):
            # flatten the patch dimension and repeat the targets to match
            if len(images.shape) == 6:
                _, npatchs, c, x, y, z = images.shape
                images = images.view(-1, c, x, y, z)
                targets = torch.cat([targets for _ in range(npatchs)]).squeeze()
            images = images.cuda(device, non_blocking=True)
            targets = targets.cuda(device, non_blocking=True)

            output = model(images)
            loss = valid_criterion(output, targets)
            valid_report.loss += loss.item()

            if npatchs == 1:
                score = F.softmax(output, dim=1)
            else:
                # average the softmax over the nine crops of one subject
                score = torch.mean(F.softmax(output, dim=1),
                                   dim=0, keepdim=True)
            valid_report.update_true(true)
            valid_report.update_score(score)
            pbar.update()
        pbar.close()

        # report
        vloss = valid_report.loss / len(validloader)
        summary.scalar('accuracy', 'train ' + fold_str, epoch,
                       train_report.accuracy, ytickmin=-0.05, ytickmax=1.05)
        summary.scalar('loss', 'valid ' + fold_str, epoch, vloss,
                       ytickmin=0, ytickmax=0.8)
        summary.scalar('accuracy', 'valid ' + fold_str, epoch,
                       valid_report.accuracy, ytickmin=-0.05, ytickmax=1.05)

        is_best = False
        if best_score['loss'] > vloss:
            best_score['loss'] = vloss
            best_score['epoch'] = epoch
            best_score['accuracy'] = valid_report.accuracy
            is_best = True

        print('Best Epoch {}: validation loss {} accuracy {}'.format(
            best_score['epoch'], best_score['loss'], best_score['accuracy']))

        # save
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        save_checkpoint(
            dict(epoch=epoch,
                 best_score=best_score,
                 state_dict=state_dict,
                 optimizer_state_dict=optimizer.state_dict()),
            FLG.checkpoint_root, FLG.running_fold, FLG.model, is_best)

        pbar = None
        timer.toc()
    print('Time elapse {}h {}m {}s'.format(*timer.total()))
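
# A minimal sketch of an entry point, assuming this script is run directly;
# the original fragment defines main() but never shows it being invoked.
if __name__ == '__main__':
    main()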
        out = out.reshape(out.size(0), -1)
        out = self.drop_out(out)
        out = self.fc1(out)
        out = self.fc2(out)
        return out


if __name__ == '__main__':
    FG = train_args()
    vis = Visdom(port=FG.vis_port, env=str(FG.vis_env))

    # torch setting
    device = torch.device('cuda:{}'.format(FG.devices[0]))
    torch.cuda.set_device(FG.devices[0])
    timer = SimpleTimer()

    # Hyperparameters
    num_epochs = 100
    num_classes = 10
    batch_size = 100          # batch_size = FG.batch_size
    learning_rate = 0.001
    save_dir = FG.save_dir
    input_size = 28
    z_dim = 62
    SUPERVISED = True
    len_discrete_code = 10    # categorical code length (i.e. label)
    len_continuous_code = 2   # continuous code length
    # sample_num = len_discrete_code ** 2
    sample_num = 100
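
    # A minimal sketch (kept as comments, since the model and loader are
    # defined elsewhere) of the training loop these hyperparameters feed;
    # `ConvNet` and `trainloader` are placeholder names not shown here:
    #
    #   model = ConvNet(num_classes).to(device)
    #   criterion = torch.nn.CrossEntropyLoss()
    #   optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    #   for epoch in range(num_epochs):
    #       for images, targets in trainloader:
    #           images, targets = images.to(device), targets.to(device)
    #           optimizer.zero_grad()
    #           loss = criterion(model(images), targets)
    #           loss.backward()
    #           optimizer.step()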
class infoGAN(object):
    def __init__(self, FG, SUPERVISED=True):
        # parameters
        self.num_epoch = FG.num_epoch
        self.batch_size = FG.batch_size
        self.save_dir = FG.save_dir
        self.result_dir = FG.result_dir
        self.dataset = 'MRI'
        self.log_dir = FG.log_dir
        self.model_name = 'infoGAN'
        self.input_size = FG.input_size
        self.z_dim = FG.z
        self.SUPERVISED = SUPERVISED  # if True, label info is used directly for the code
        self.len_discrete_code = 10   # categorical distribution (i.e. label)
        self.len_continuous_code = 2  # gaussian distribution (e.g. rotation, thickness)
        self.sample_num = self.len_discrete_code ** 2

        # torch setting
        self.device = torch.device('cuda:{}'.format(FG.devices[0]))
        torch.cuda.set_device(FG.devices[0])
        timer = SimpleTimer()

        # load dataset: x = image, y = target
        x, y = Trainset(FG)
        trainset = ADNIDataset(
            FG, x, y,
            cropping=NineCrop((40, 40, 40), (32, 32, 32)),
            transform=Compose([Lambda(lambda patches: torch.stack(
                [ToTensor()(patch) for patch in patches]))]))
        self.trainloader = DataLoader(trainset, batch_size=self.batch_size,
                                      shuffle=True, pin_memory=True,
                                      num_workers=4)

        # peek at one batch to infer the input channel count
        for _, data in enumerate(self.trainloader):
            data = data['image']
            break

        # networks init
        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1],
                           input_size=self.input_size,
                           len_discrete_code=self.len_discrete_code,
                           len_continuous_code=self.len_continuous_code).to(self.device)
        self.D = discriminator(input_dim=data.shape[1], output_dim=1,
                               input_size=self.input_size,
                               len_discrete_code=self.len_discrete_code,
                               len_continuous_code=self.len_continuous_code).to(self.device)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=FG.lrG,
                                      betas=(FG.beta1, FG.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=FG.lrD,
                                      betas=(FG.beta1, FG.beta2))
        # the mutual-information term is optimized through both G and D
        self.info_optimizer = optim.Adam(
            itertools.chain(self.G.parameters(), self.D.parameters()),
            lr=FG.lrD, betas=(FG.beta1, FG.beta2))

        if len(FG.devices) != 1:
            self.G = torch.nn.DataParallel(self.G, FG.devices)
            self.D = torch.nn.DataParallel(self.D, FG.devices)

        self.BCE_loss = nn.BCELoss().to(self.device)
        self.CE_loss = nn.CrossEntropyLoss().to(self.device)
        self.MSE_loss = nn.MSELoss().to(self.device)

        print('---------- Networks architecture -------------')
        ori_utils.print_network(self.G)
        ori_utils.print_network(self.D)
        print('-----------------------------------------------')

        # fixed noise & condition: one shared z per row of discrete codes
        self.sample_z = torch.zeros((self.sample_num, self.z_dim))
        for i in range(self.len_discrete_code):
            self.sample_z[i * self.len_discrete_code] = torch.rand(1, self.z_dim)
            for j in range(1, self.len_discrete_code):
                self.sample_z[i * self.len_discrete_code + j] = \
                    self.sample_z[i * self.len_discrete_code]

        temp = torch.zeros((self.len_discrete_code, 1))
        for i in range(self.len_discrete_code):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(self.len_discrete_code):
            temp_y[i * self.len_discrete_code:(i + 1) * self.len_discrete_code] = temp

        self.sample_y = torch.zeros((self.sample_num, self.len_discrete_code)) \
            .scatter_(1, temp_y.type(torch.LongTensor), 1)
        self.sample_c = torch.zeros((self.sample_num, self.len_continuous_code))

        # manipulating the two continuous codes: one fixed z for every sample
        # self.sample_z2 = torch.rand((1, self.z_dim)).expand(self.sample_num, self.z_dim)
        self.sample_z2 = torch.zeros((self.sample_num, self.z_dim))
        z2 = torch.rand(1, self.z_dim)
        for i in range(self.sample_num):
            self.sample_z2[i] = z2
        self.sample_y2 = torch.zeros(self.sample_num, self.len_discrete_code)
        self.sample_y2[:, 0] = 1

        temp_c = torch.linspace(-1, 1, 10)
        self.sample_c2 = torch.zeros((self.sample_num, 2))
        for i in range(self.len_discrete_code):
            for j in range(self.len_discrete_code):
                self.sample_c2[i * self.len_discrete_code + j, 0] = temp_c[i]
                self.sample_c2[i * self.len_discrete_code + j, 1] = temp_c[j]

        self.sample_z = self.sample_z.cuda(self.device, non_blocking=True)
        self.sample_y = self.sample_y.cuda(self.device, non_blocking=True)
        self.sample_c = self.sample_c.cuda(self.device, non_blocking=True)
        self.sample_z2 = self.sample_z2.cuda(self.device, non_blocking=True)
        self.sample_y2 = self.sample_y2.cuda(self.device, non_blocking=True)
        self.sample_c2 = self.sample_c2.cuda(self.device, non_blocking=True)

        vis = Visdom(port=10002, env=str(FG.vis_env))
        self.printers = dict(
            D_loss=Scalar(vis, 'D_loss', opts=dict(
                showlegend=True, title='D loss', ytickmin=0, ytickmax=2.0)),
            G_loss=Scalar(vis, 'G_loss', opts=dict(
                showlegend=True, title='G loss', ytickmin=0, ytickmax=10)),
            info_loss=Scalar(vis, 'info_loss', opts=dict(
                showlegend=True, title='info loss', ytickmin=0, ytickmax=10)),
            input=Image3D(vis, 'input'),
            input_fi=Image3D(vis, 'input_fi'),
            output=Image3D(vis, 'output'),
            output2=Image3D(vis, 'output2'))
        self.timer = SimpleTimer()

    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['info_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.y_real = torch.ones(self.batch_size, 1).cuda(self.device, non_blocking=True)
        self.y_fake = torch.zeros(self.batch_size, 1).cuda(self.device, non_blocking=True)

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.num_epoch):
            self.timer.tic()
            torch.set_grad_enabled(True)
            self.G.train()
            epoch_start_time = time.time()
            for iter, data in enumerate(self.trainloader):
                image = data['image']
                y = data['target']
                # if iter == len(self.trainloader) // self.batch_size:
                #     break

                z = torch.rand((self.batch_size, self.z_dim))
                if self.SUPERVISED:
                    y_disc = torch.zeros((self.batch_size, self.len_discrete_code)) \
                        .scatter_(1, y.type(torch.LongTensor).unsqueeze(1), 1)
                else:
                    y_disc = torch.from_numpy(np.random.multinomial(
                        1, self.len_discrete_code * [1.0 / self.len_discrete_code],
                        size=[self.batch_size])).type(torch.FloatTensor)
                y_cont = torch.from_numpy(np.random.uniform(
                    -1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor)
                z = z.cuda(self.device, non_blocking=True)
                y_disc = y_disc.cuda(self.device, non_blocking=True)
                y_cont = y_cont.cuda(self.device, non_blocking=True)

                # take the central axial slice of the first nine-crop patch
                ci = (np.array(image.shape) / 2).astype(int)
                x = image[:, :, 0, :, :, ci[5]].contiguous() \
                    .cuda(self.device, non_blocking=True)

                # augment with volumes flipped along the first two axes
                flipped = []
                for cim in image:
                    fi = np.flip(cim, 0)
                    flipped += [fi]
                    fi = np.flip(cim, 1)
                    flipped += [fi]
                for i in range(len(flipped)):
                    flipped[i] = torch.from_numpy(flipped[i].copy()).float()
                flipped = torch.stack(flipped)
                flipped = flipped[:, :, 0, :, :, ci[5]].contiguous() \
                    .cuda(self.device, non_blocking=True)

                # ---------------- Update D network ----------------
                self.D_optimizer.zero_grad()
                self.printers['input']('input', x[1, :, :, :])

                D_real, _, _ = self.D(x)
                if D_real.shape[0] == 1:
                    break
                batch_size = D_real.size(0)
                self.y_real.resize_(batch_size).fill_(1)
                D_real_loss = self.BCE_loss(D_real, self.y_real)

                self.printers['input_fi']('input_fi', flipped[1, :, :, :])
                # note: this overwrites the D_real_loss computed on x above,
                # so only the flipped batch contributes to D_loss
                D_real, _, _ = self.D(flipped)
                if D_real.shape[0] == 1:
                    break
                batch_size = D_real.size(0)
                self.y_real.resize_(batch_size).fill_(1)
                D_real_loss = self.BCE_loss(D_real, self.y_real)

                fake = self.G(z, y_cont, y_disc)
                D_fake, _, _ = self.D(fake)
                batch_size = D_fake.size(0)
                self.y_fake.resize_(batch_size).fill_(0)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.item())
                D_loss.backward(retain_graph=True)
                self.D_optimizer.step()

                # ---------------- Update G network ----------------
                self.G_optimizer.zero_grad()

                fake = self.G(z, y_cont, y_disc)
                D_fake, D_cont, D_disc = self.D(fake)
                batch_size = D_fake.size(0)
                y_real = torch.ones(batch_size).cuda(self.device, non_blocking=True)
                G_loss = self.BCE_loss(D_fake, y_real)
                self.train_hist['G_loss'].append(G_loss.item())
                G_loss.backward(retain_graph=True)
                self.G_optimizer.step()

                # information loss
                disc_loss = self.CE_loss(D_disc, torch.max(y_disc, 1)[1])
                cont_loss = self.MSE_loss(D_cont, y_cont)
                info_loss = disc_loss + cont_loss
                self.train_hist['info_loss'].append(info_loss.item())
                info_loss.backward()
                self.info_optimizer.step()

                step = epoch + iter / len(self.trainloader)
                self.printers['D_loss']('D_fake_loss', step, D_fake_loss)
                self.printers['D_loss']('D_real_loss', step, D_real_loss)
                self.printers['D_loss']('D_loss', step, D_loss)
                self.printers['G_loss']('G_loss', step, G_loss)
                self.printers['info_loss']('disc_loss', step, disc_loss)
                self.printers['info_loss']('cont_loss', step, cont_loss)
                self.printers['info_loss']('info_loss', step, info_loss)

                fake = fake[1, :, :, :]
                self.printers['output']('ori_output', fake)
                self.printers['output2']('output2', fake)

                if ((iter + 1) % 10) == 0:
                    # len(self.trainloader) is already the number of batches
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, info_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.trainloader),
                           D_loss.item(), G_loss.item(), info_loss.item()))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            # if (epoch % 10) == 0:
            #     with torch.no_grad():
            #         self.visualize_results(epoch + 1)
            #     self.loss_plot(self.train_hist,
            #                    os.path.join(self.save_dir, self.dataset, self.model_name),
            #                    self.model_name)
            self.timer.toc()
            print('Time elapse {}h {}m {}s'.format(*self.timer.total()))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" %
              (np.mean(self.train_hist['per_epoch_time']), self.num_epoch,
               self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        ori_utils.generate_animation(
            self.result_dir + '/' + self.dataset + '/' + self.model_name
            + '/' + self.model_name, self.num_epoch)
        ori_utils.generate_animation(
            self.result_dir + '/' + self.dataset + '/' + self.model_name
            + '/' + self.model_name + '_cont', self.num_epoch)
        self.loss_plot(self.train_hist,
                       os.path.join(self.save_dir, self.dataset, self.model_name),
                       self.model_name)

    def visualize_results(self, epoch):
        self.G.eval()
        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))

        """ style by class """
        samples = self.G(self.sample_z, self.sample_c, self.sample_y)
        samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)  # (100, 32, 32, C)
        samples = (samples + 1) / 2
        ori_utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                              [image_frame_dim, image_frame_dim],
                              self.save_dir + '/info-2D/disc_epoch%03d' % epoch + '.png')

        """ manipulating two continuous codes """
        # sample_y2 is the fixed class code paired with sample_z2/sample_c2
        samples = self.G(self.sample_z2, self.sample_c2, self.sample_y2)
        samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        samples = (samples + 1) / 2
        ori_utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :],
                              [image_frame_dim, image_frame_dim],
                              self.save_dir + '/info-2D/cont_epoch%03d' % epoch + '.png')

    def save(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        torch.save(self.G.state_dict(),
                   os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(),
                   os.path.join(save_dir, self.model_name + '_D.pkl'))
        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        self.G.load_state_dict(
            torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(
            torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))

    def loss_plot(self, hist, path='Train_hist.png', model_name=''):
        x = range(len(hist['D_loss']))
        y1 = hist['D_loss']
        y2 = hist['G_loss']
        y3 = hist['info_loss']

        plt.plot(x, y1, label='D_loss')
        plt.plot(x, y2, label='G_loss')
        plt.plot(x, y3, label='info_loss')
        plt.xlabel('Iter')
        plt.ylabel('Loss')
        plt.legend(loc=4)
        plt.grid(True)
        plt.tight_layout()
        path = os.path.join(path, model_name + '_loss.png')
        plt.savefig(path)
        plt.close()
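
# A minimal sketch of a driver for this class, assuming train_args() supplies
# the FG flags used above (devices, num_epoch, batch_size, z, input_size,
# lrG, lrD, beta1, beta2, save_dir, result_dir, log_dir, vis_env); the
# original file does not show how infoGAN is instantiated.
if __name__ == '__main__':
    FG = train_args()
    gan = infoGAN(FG, SUPERVISED=True)
    gan.train()
    with torch.no_grad():
        gan.visualize_results(FG.num_epoch)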