def train_loop(model, loader, test_loader, opt):
    """Train an autoencoder on `loader`, logging per-epoch reconstruction
    error and sample reconstructions to tensorboard, saving weights each
    epoch.

    NOTE(review): `test_loader` is accepted but never used in this body —
    confirm whether evaluation was intended here.
    """
    device = torch.device('cuda:{}'.format(opt.cuda))
    print(opt.exp)
    # Adam with GAN-style betas (0.5, 0.999); learning rate is hard-coded.
    optim = torch.optim.Adam(model.parameters(), 5e-4, betas=(0.5, 0.999))
    writer = SummaryWriter('tblog/%s' % opt.exp)
    for e in tqdm(range(opt.epochs)):
        losses = []
        model.train()
        for (x, _) in tqdm(loader):
            x = x.to(device)
            # Replicate single-channel (grayscale) inputs to 3 channels.
            if x.size(1) == 1:
                x = x.repeat(1, 3, 1, 1)
            x.requires_grad = False
            out = model(x)
            rec_err = (out - x)**2
            loss = rec_err.mean()  # mean squared reconstruction error
            losses.append(loss.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        losses = np.mean(losses)  # epoch average (rebinds the list)
        writer.add_scalar('rec_err', losses, e)
        # Log inputs + reconstructions from the LAST batch; inputs are
        # presumably normalized to [-1, 1], hence * 0.5 + 0.5 — TODO confirm.
        writer.add_images('recons', torch.cat((x, out)).cpu() * 0.5 + 0.5, e)
        print('epochs:{}, recon error:{}'.format(e, losses))
        # Overwrites the same checkpoint every epoch.
        torch.save(model.state_dict(), 'models/{}.pth'.format(opt.exp))
def train(self, data_loader, epochs=20, log_dir="runs/test/", log_freq=500):
    """Run the VAE training loop.

    Trains self.enc / self.dec with self.optim on `data_loader`, logging the
    ELBO loss each iteration and sample grids every `log_freq` iterations,
    and saving a checkpoint at the end of every epoch.
    """
    tb_logger = SummaryWriter(log_dir=log_dir)
    self.log_dir = log_dir
    # Fixed noise so epoch-end samples are comparable across epochs;
    # used for tensorboard logging only.
    constant_noise = torch.randn(64, self.latent_dim, device=self.device)
    for ep in range(1, epochs + 1):
        print("\n", "=" * 35, f"training epoch {ep}", "=" * 35, "\n")
        for it, (imgs, _) in tqdm(enumerate(data_loader)):
            self.glob_it += 1
            imgs = imgs.to(self.device)
            enc_out = self.enc(imgs)
            # Encoder emits mean and log-variance stacked along dim 1.
            mu, logvar = enc_out.chunk(chunks=2, dim=1)
            # Reparametrization trick: z = mu + eps * sigma.
            std = torch.exp(0.5 * logvar)
            e = torch.randn(std.shape, device=self.device)
            z = mu + e * std
            imgs_recon = self.dec(z)
            loss = self.ELBO(imgs_recon, imgs, mu, logvar)
            self.enc.zero_grad()
            self.dec.zero_grad()
            loss.backward()
            self.optim.step()
            tb_logger.add_scalar("train_loss", loss, self.glob_it)
            if self.glob_it % log_freq == 0:
                # Log a 4x4 grid of fresh samples to tensorboard.
                tb_logger.add_figure("samples",
                                     self.get_mXn_samples_grid(4, 4),
                                     self.glob_it)
                print(
                    f"epoch {ep}, iter {it} (total iter {self.glob_it}): train_loss = {loss}"
                )
        # Per-epoch logging (uses the last batch's loss).
        print(
            f"epoch {ep}, iter {it} (total iter {self.glob_it}): train_loss = {loss}"
        )
        # sample1 uses the fixed noise; sample2 draws fresh noise each epoch.
        tb_logger.add_images("epoch/sample1",
                             self.sample(noise=constant_noise), ep)
        tb_logger.add_images("epoch/sample2", self.sample(num_images=64), ep)
        # Save model at end of each epoch.
        self.save_model(model_name=f"vae_ep{ep}.pt", idx=self.glob_it)
class TensorboardVisualizer:
    """ A wrapper class for tensorboardX.

    Note: Original tensorboardX API call is supported (unknown attributes
    fall through to the underlying SummaryWriter via __getattr__).
    """

    def __init__(self, experiment=None, **kwargs):
        # Runs get a "_<experiment>" suffix in their auto-generated dir name.
        comment = '_' + str(experiment) if experiment is not None else ''
        if self.experiment_exists(experiment):
            # Set has_writer BEFORE raising so __del__ (which always runs)
            # does not touch the never-created writer.
            self.has_writer = False
            raise ValueError(
                'experiment [{}] already exists.'.format(experiment))
        else:
            self.has_writer = True
            self.writer = SummaryWriter(comment=comment, **kwargs)

    def __del__(self):
        if self.has_writer:
            self.writer.close()

    @staticmethod
    def experiment_exists(experiment):
        # NOTE(review): `mv` appears to be a project utility module exposing
        # isdir/listdir/make_np — confirm it mirrors os/os.path semantics.
        if not mv.isdir('runs'):
            return False
        all_experiments = mv.listdir('runs')
        # Run dirs look like "<date>_<host>_<experiment>"; keep the
        # experiment part (split on at most 3 underscores).
        all_experiments = [e.split('_', 3)[-1] for e in all_experiments]
        if experiment in all_experiments:
            return True
        else:
            return False

    def plot(self, name, x, y, group='data'):
        # Scalar plot under "<group>/<name>": x is the step, y the value.
        tag = group + '/' + name
        self.writer.add_scalar(tag, y, x)

    def imshow(self, name, img_tensor, global_step=None, group='image'):
        """ Accept NCHW numpy.ndarray or torch tensor as input. """
        imgs = mv.make_np(img_tensor)
        assert imgs.ndim == 4
        # Grayscale -> RGB by channel replication.
        if imgs.shape[1] == 1:
            imgs = np.concatenate([imgs, imgs, imgs], 1)
        tag = group + '/' + name
        self.writer.add_images(tag, imgs, global_step)

    def log(self, info, global_step=None, tag='text/log'):
        # Timestamped text entry; "<br>" renders a line break in the TB UI.
        text_string = '[{time}] {info} <br>'.format(
            time=time.strftime('%y-%m-%d %H:%M:%S'), info=info)
        self.writer.add_text(tag, text_string, global_step)

    def __getattr__(self, name):
        # Delegate any other attribute/API call to the wrapped writer.
        return getattr(self.writer, name)
class MetricCounter:
    """Accumulates per-iteration losses/metrics and images, and flushes
    averaged values (and queued images) to tensorboard once per epoch."""

    def __init__(self, exp_name=None):
        self.writer = SummaryWriter(exp_name)
        logging.basicConfig(filename='{}.log'.format(exp_name),
                            level=logging.DEBUG)
        self.metrics = defaultdict(list)  # metric name -> list of values
        # tag -> list of image arrays; presumably HWC BGR uint8 (they are
        # channel-flipped and /255 on write) — TODO confirm against callers.
        self.images = defaultdict(list)
        self.best_metric = 0  # best epoch-mean PSNR seen so far

    def add_image(self, x: np.ndarray, tag: str):
        self.images[tag].append(x)

    def clear(self):
        # Reset all accumulated metrics and images (e.g. between epochs).
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G, l_content, l_D=0):
        # The adversarial part of the generator loss is the remainder
        # l_G - l_content.
        for name, value in zip(
            ('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss'),
                (l_G, l_content, l_G - l_content, l_D)):
            self.metrics[name].append(value)

    def add_metrics(self, psnr, ssim):
        for name, value in zip(('PSNR', 'SSIM'), (psnr, ssim)):
            self.metrics[name].append(value)

    def loss_message(self):
        # Human-readable summary over the last WINDOW_SIZE iterations
        # (WINDOW_SIZE is a module-level constant).
        metrics = ((k, np.mean(self.metrics[k][-WINDOW_SIZE:]))
                   for k in ('G_loss', 'PSNR', 'SSIM'))
        return '; '.join(map(lambda x: f'{x[0]}={x[1]:.4f}', metrics))

    def write_to_tensorboard(self, epoch_num, validation=False):
        """Write epoch-mean scalars and all queued images, then clear the
        image queues."""
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content',
                    'SSIM', 'PSNR'):
            self.writer.add_scalar(f'{scalar_prefix}_{tag}',
                                   np.mean(self.metrics[tag]),
                                   global_step=epoch_num)
        for tag in self.images:
            imgs = self.images[tag]
            if imgs:
                imgs = np.array(imgs)
                # ::-1 flips BGR -> RGB; /255 scales uint8 to [0, 1].
                self.writer.add_images(tag,
                                       imgs[:, :, :, ::-1].astype('float32') /
                                       255,
                                       dataformats='NHWC',
                                       global_step=epoch_num)
                self.images[tag] = []

    def update_best_model(self):
        # Returns True when the current epoch-mean PSNR beats the best so
        # far (caller is expected to checkpoint on True).
        cur_metric = np.mean(self.metrics['PSNR'])
        if self.best_metric < cur_metric:
            self.best_metric = cur_metric
            return True
        return False
def train(self):
    """Training loop for the inpainting generator (and optional edge
    discriminator): iterates forever over the train dataset, periodically
    logging losses/images to tensorboard and saving checkpoints.

    NOTE(review): `keep_training` is never set False, so this loop only ends
    when the process is killed — confirm that is intentional.
    """
    writer = SummaryWriter(log_dir="log_info")
    self.G.train()
    if self.opt.finetune:
        print("here")
        # When fine-tuning, optimize only parameters left trainable.
        self.optm_G = optim.Adam(filter(lambda p: p.requires_grad,
                                        self.G.parameters()),
                                 lr=self.lr)
    train_loader = DataLoader(dataset=self.train_dataset,
                              batch_size=self.opt.batch_size,
                              num_workers=self.opt.n_threads,
                              drop_last=True,
                              shuffle=True)
    keep_training = True
    epoch = 0
    i = self.start_iter  # global iteration counter (resumes from checkpoint)
    print("starting training")
    s_time = time.time()
    while keep_training:
        epoch += 1
        print("epoch: {:d}".format(epoch))
        for items in train_loader:
            i += 1
            gt_images, gray_image, gt_edges, masks = self.cuda(*items)
            # masks = torch.cat([masks]*3, dim = 1)
            self.gray_image = gray_image
            masked_images = gt_images * masks
            # Edge map is single-channel; use the mask's first channel only.
            masked_edges = gt_edges * masks[:, 0:1, :, :]
            self.forward(masked_images, masks, masked_edges, gt_images,
                         gt_edges)
            self.update_parameters()
            if i % self.opt.log_interval == 0:
                e_time = time.time()
                int_time = e_time - s_time
                print("epoch:{:d}, iteration:{:d}".format(epoch, i),
                      ", l1_loss:", self.l1_loss / self.opt.log_interval,
                      ", time_taken:", int_time)
                writer.add_scalars(
                    "loss_val", {
                        "l1_loss": self.l1_loss * self.opt.batch_size /
                        self.opt.log_interval,
                        "D_loss": self.D_loss / self.opt.log_interval,
                        "E_loss": self.E_loss * self.opt.batch_size /
                        self.opt.log_interval
                    }, i)
                masked_images = masked_images.cpu()
                fake_images = self.fake_B.cpu()
                fake_edges = self.edge_fake[1].cpu()
                # Replicate 1-channel edges to 3 channels for display.
                fake_edges = torch.cat([fake_edges] * 3, dim=1)
                images = torch.cat(
                    [masked_images[0:3], fake_images[0:3], fake_edges[0:3]],
                    dim=0)
                writer.add_images("imgs", images, i)
                s_time = time.time()
                # Running loss accumulators reset after each log window.
                self.l1_loss = 0.0
                self.D_loss = 0.0
                self.E_loss = 0.0
            if i % self.opt.save_interval == 0:
                save_ckpt(
                    '{:s}/ckpt/g_{:d}.pth'.format(self.opt.save_dir, i),
                    [('generator', self.G)], [('optimizer_G', self.optm_G)],
                    i)
                if self.have_D:
                    save_ckpt(
                        '{:s}/ckpt/d_{:d}.pth'.format(self.opt.save_dir, i),
                        [('edge_D', self.edge_D)],
                        [('optimizer_ED', self.optm_ED)], i)
    writer.close()
class Visualizer():
    """Thin tensorboard wrapper used by train scripts to log image
    dictionaries and grouped scalar dictionaries."""

    def __init__(self, top_out_path):
        # This will cause error in the very old train scripts
        self.writer = SummaryWriter(top_out_path)

    def log_images(self, visuals, step):
        # |visuals|: dictionary of images to save
        for label in visuals:
            image_numpy = visuals[label]
            self.writer.add_images(label, [image_numpy], step)

    def log_scalars(self, scalars, step, main_tag='metrics'):
        # scalars: dictionary of scalar labels and values
        self.writer.add_scalars(main_tag=main_tag,
                                tag_scalar_dict=scalars,
                                global_step=step)
def forward(self, x):
    """Forward pass through conv1..conv4 and the output head.

    On the first call with self.test truthy, each stage's feature maps are
    dumped to tensorboard under ./runs/Pict, after which self.test is set to
    0 so logging happens only once.
    """
    self.x1 = x
    if self.test:
        # "TezhengTu" = "feature map" (Chinese pinyin).
        TezhengTuWriter = SummaryWriter('./runs/Pict')
        TezhengTuWriter.add_image('countdown_1', self.x1[0], global_step=0,
                                  dataformats='CHW')
    x = self.conv1(x)
    self.x2 = x
    if self.test:
        # GYH presumably converts one sample's channels into an NCHW batch
        # of per-channel images — TODO confirm its contract.
        TezhengTuWriter.add_images('countdown_2', GYH(self.x2[0]),
                                   global_step=1, dataformats='NCHW')
    x = self.conv2(x)
    self.x3 = x
    if self.test:
        TezhengTuWriter.add_images('countdown_3', GYH(self.x3[0]),
                                   global_step=2, dataformats='NCHW')
    x = self.conv3(x)
    self.x4 = x
    if self.test:
        TezhengTuWriter.add_images('countdown_4', GYH(self.x4[0]),
                                   global_step=3, dataformats='NCHW')
    x = self.conv4(x)
    self.x5 = x
    if self.test:
        TezhengTuWriter.add_images('countdown_5', GYH(self.x5[0]),
                                   global_step=4, dataformats='NCHW')
        TezhengTuWriter.close()
        self.test = 0  # disable feature-map logging for subsequent calls
    x = x.view(
        x.size(0),
        -1)  # flatten the output of conv2 to (batch_size, 32 * 7 * 7)
    output = self.out(x)
    return output, x  # return x for visualization
def mark_pru():
    """Benchmark the pruned encoder/decoder pair on the NYU benchmark split.

    Logs standard depth-error metrics and prediction/label images per sample
    to tensorboard, and prints the running mean of the a3 accuracy.
    """
    net_e = load_encoder(after_f=True)
    net_e = nn.DataParallel(net_e).cuda()
    net_d = load_decoder(after_f=True)
    net_d = nn.DataParallel(net_d).cuda()
    data_loader = nyu_set.use_nyu_data(batch_s=1, max_len=400,
                                       isBenchmark=True)
    writer1 = SummaryWriter('/data/consistent_depth/gj_dir/benchmark_p2')
    with torch.no_grad():
        num = 0
        su = 0  # running sum of a3 accuracy
        for data, label in data_loader:
            num += 1
            data = autograd.Variable(data.double().cuda(),
                                     requires_grad=False)
            prediction_d = net_d(net_e(data))
            abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 = compute_depth_errors(
                label, prediction_d)
            writer1.add_images('pre', prediction_d, global_step=num)
            writer1.add_scalar('rmse', rmse, global_step=num)
            writer1.add_scalar("abs_rel", abs_rel, global_step=num)
            writer1.add_scalar('sq_rel', sq_rel, global_step=num)
            writer1.add_scalar('rmse_log', rmse_log, global_step=num)
            writer1.add_scalar('a1', a1, global_step=num)
            writer1.add_scalar('a2', a2, global_step=num)
            writer1.add_scalar('a3', a3, global_step=num)
            writer1.add_images('label', label, global_step=num)
            su += a3.item()
            print(su / num)  # running mean a3 so far
    # scaled_disp, _ = disp_to_depth(disp, 0.1, 10)
    # Saving colormapped depth image
    # vmax = np.percentile(disp_resized_np, 95)
    writer1.close()
    print('-> Done!')
class TensorboardLogger(Logger):
    """Logger backed by a tensorboard SummaryWriter.

    Each log_* method forwards (tag, data, global_step) to the matching
    SummaryWriter.add_* call.
    """

    def __init__(self, path):
        # path: directory the tensorboard event files are written to.
        self.writer = SummaryWriter(path)

    def log_image(self, name, data, step):
        self.writer.add_image(name, data, step)

    def log_image_batch(self, name, data, step):
        self.writer.add_images(name, data, step)

    def log_number(self, name, data, step):
        self.writer.add_scalar(name, data, step)

    def log_text(self, name, data, step):
        self.writer.add_text(name, data, step)

    def log_figure(self, name, data, step):
        self.writer.add_figure(name, data, step)

    def log_embedding(self, name, data, step):
        # BUG FIX: SummaryWriter.add_embedding's signature is
        # add_embedding(mat, metadata=None, label_img=None, global_step=None,
        # tag='default'); previously `name` was passed positionally as the
        # matrix and `step` as label_img. Pass the tag and step by keyword.
        self.writer.add_embedding(data, global_step=step, tag=name)
class Logger():
    """Small facade over a tensorboard SummaryWriter with a cap on how many
    images are written per add_images call."""

    def __init__(self, log_path='log'):
        self.logger = SummaryWriter(log_path)
        self.add_images_maxnum = 10  # max images forwarded per add_images

    def add_scalar(self, tag, value, step):
        # Accepts a single scalar, or a dict of named tensors whose means
        # are logged under their own keys (tag is ignored in that case).
        if isinstance(value, dict):
            for k, v in value.items():
                self.logger.add_scalar(k, v.mean().item(), step)
        else:
            self.logger.add_scalar(tag, value, step)

    def add_image(self, tag, value, step, dataformats='CHW'):
        # NOTE: `dataformats` is accepted for API symmetry but not forwarded
        # (behavior preserved).
        self.logger.add_image(tag, value, step)

    def add_images(self, tag, value, step, dataformats='CHW'):
        # Cap the number of images to keep event files small.
        value = value[:self.add_images_maxnum]
        self.logger.add_images(tag, value, step)

    def add_graph(self, model, input_to_model=None):
        """Log the model graph; defaults to a (256, 3, 32, 128) zero input."""
        if input_to_model is None:
            input_to_model = torch.zeros((256, 3, 32, 128))
        # BUG FIX: previously called self.add_graph(...) — unconditional
        # infinite recursion. Forward to the underlying SummaryWriter.
        self.logger.add_graph(model, input_to_model)
def benchmark_pruned():
    """Benchmark the (teacher) depth network on the NYU benchmark split,
    logging RMSE / relative error / MSE loss and prediction/label images to
    tensorboard."""
    net = load_t_net(file=True)
    #net = load_pru_mod(after_finetune=True).double()
    net = nn.DataParallel(net)
    net = net.cuda()
    data_loader = nyu_set.use_nyu_data(batch_s=4, max_len=100,
                                       isBenchmark=True)
    writer1 = SummaryWriter('./gj_dir/benchmark_t_mod')
    Joint = JointLoss(opt=None).double().cuda()
    criterion = nn.MSELoss(reduction='mean').cuda()
    net.eval()
    num = 0
    for data, label in data_loader:
        num += 1
        target = label2target(label)
        # BUG FIX: the loop yields `data`; previously `images` was used on
        # its own right-hand side before ever being assigned (NameError on
        # the first iteration).
        images = autograd.Variable(data.double().cuda(), requires_grad=False)
        prediction_d = net.forward(images)[0]  # 0 is depth, 1 is confidence
        e_rmse = Joint.compute_rmse_error(prediction_d, target)
        e_rel = Joint.compute_l1_rel_error(prediction_d, target)
        loss = criterion(prediction_d, target["depth_gt"])
        writer1.add_images('pre', prediction_d, global_step=num)
        writer1.add_scalar('rmse', e_rmse, global_step=num)
        writer1.add_scalar("rel", e_rel, global_step=num)
        writer1.add_scalar('loss', loss, global_step=num)
        writer1.add_images('label', label, global_step=num)
        print("ok")
class tf_recorder:
    """Tensorboard recorder that owns its own global-step counter (`niter`)
    and auto-picks a fresh "<name>_<i>" run directory under `log_dir`."""

    def __init__(self, network_name, log_dir):
        # Create the log root if needed (was `os.system('mkdir -p ...')`,
        # which is shell-dependent; os.makedirs is portable and equivalent).
        os.makedirs(log_dir, exist_ok=True)
        # Pick the first unused "<network_name>_<i>" run directory.
        for i in range(1000):
            self.targ = os.path.join(log_dir,
                                     '{}_{}'.format(network_name, i))
            if not os.path.exists(self.targ):
                self.writer = SummaryWriter(self.targ)
                break
        # BUG FIX: `niter` was only initialized by renew(), so add_scalar /
        # add_images raised AttributeError if renew() was never called.
        self.niter = 0

    def renew(self, subname):
        # Start a fresh sub-run "<targ>_<subname>" and reset the step count.
        self.writer = SummaryWriter('{}_{}'.format(self.targ, subname))
        self.niter = 0

    def add_scalar(self, index, val):
        self.writer.add_scalar(index, val, self.niter)

    def add_scalars(self, index, group_dict):
        # BUG FIX: previously forwarded to writer.add_scalar, which cannot
        # take a dict of values; grouped scalars need add_scalars.
        self.writer.add_scalars(index, group_dict, self.niter)

    def add_images(self, tag, images):
        self.writer.add_images(tag, images, self.niter)

    def iter(self, tick=1):
        # Advance the shared global step used by all add_* calls.
        self.niter += tick
def train_loop(model, loader, test_loader, opt):
    """Train an autoencoder on `loader`, evaluating AUC on `test_loader`
    each epoch.

    Two modes, selected by opt.u:
      * opt.u falsy: plain MSE reconstruction.
      * opt.u truthy: the model predicts (mean, logvar) and is trained with
        an uncertainty-weighted loss exp(-logvar)*err + logvar.
    """
    device = torch.device('cuda:{}'.format(opt.cuda))
    print(opt.exp)
    # Adam with GAN-style betas (0.5, 0.999); learning rate hard-coded.
    optim = torch.optim.Adam(model.parameters(), 5e-4, betas=(0.5, 0.999))
    writer = SummaryWriter('log/%s' % opt.exp)
    for e in tqdm(range(opt.epochs)):
        # l1s: reconstruction errors; l2s: mean log-variances (opt.u only).
        l1s, l2s = [], []
        model.train()
        for (x, _) in tqdm(loader):
            x = x.to(device)
            x.requires_grad = False
            if not opt.u:
                out = model(x)
                rec_err = (out - x)**2
                loss = rec_err.mean()
                l1s.append(loss.item())
            else:
                mean, logvar = model(x)
                rec_err = (mean - x)**2
                # Heteroscedastic loss: down-weight high-variance pixels,
                # penalize large predicted variance.
                loss1 = torch.mean(torch.exp(-logvar) * rec_err)
                loss2 = torch.mean(logvar)
                loss = loss1 + loss2
                l1s.append(rec_err.mean().item())
                l2s.append(loss2.item())
            optim.zero_grad()
            loss.backward()
            optim.step()
        auc = test_for_xray(opt, model, test_loader)
        if not opt.u:
            l1s = np.mean(l1s)
            writer.add_scalar('auc', auc, e)
            writer.add_scalar('rec_err', l1s, e)
            # Inputs/reconstructions from the LAST batch; presumably
            # normalized to [-1, 1], hence * 0.5 + 0.5 — TODO confirm.
            writer.add_images('recons',
                              torch.cat((x, out)).cpu() * 0.5 + 0.5, e)
            print('epochs:{}, recon error:{}'.format(e, l1s))
        else:
            l1s = np.mean(l1s)
            l2s = np.mean(l2s)
            writer.add_scalar('auc', auc, e)
            writer.add_scalar('rec_err', l1s, e)
            writer.add_scalar('logvars', l2s, e)
            writer.add_images('recons',
                              torch.cat((x, mean)).cpu() * 0.5 + 0.5, e)
            writer.add_images('vars',
                              torch.cat((x * 0.5 + 0.5, logvar.exp())).cpu(),
                              e)
            print('epochs:{}, recon error:{}, logvars:{}'.format(
                e, l1s, l2s))
        # Overwrites the same checkpoint every epoch.
        torch.save(model.state_dict(), './models/{}.pth'.format(opt.exp))
def benchmark_pruned():
    """Benchmark the pruned depth model on NYU data.

    For each sample: run the network (which predicts log-depth), exponentiate
    to depth, invert to inverse depth, compute an RMSE-style error against
    the label, and log scalars/images to tensorboard.
    """
    net = load_pru_mod(after_finetune=True).double()
    net = nn.DataParallel(net)
    net = net.cuda()
    data_loader = nyu_set.use_nyu_data(batch_s=1, max_len=100,
                                       isBenchmark=True)
    writer1 = SummaryWriter('./gj_dir/benchmark_pru_mod')
    criterion = nn.MSELoss(reduction='mean').cuda()
    net.eval()
    num = 0
    for data, label in data_loader:
        num += 1
        # BUG FIX: the loop yields `data`; previously `images` appeared on
        # its own right-hand side before assignment (NameError).
        images = Variable(data).double().cuda()
        label = Variable(label).double().cuda()
        # Reshape ...CHW -> XCHW
        shape = images.shape
        prediction_d = net.forward(images)[0]  # 0 is depth, 1 is confidence
        out_shape = shape[:-3] + prediction_d.shape[-2:]
        prediction_d = prediction_d.reshape(out_shape)
        prediction_d = torch.exp(prediction_d)  # log-depth -> depth
        depth = prediction_d.squeeze(-3)
        # BUG FIX: keep inverse depth as a tensor. Previously `depth` was
        # converted to a squeezed numpy array, which made nn.MSELoss fail,
        # made `.item()`/torch.sqrt on the result fail, and broke add_images.
        inv_depth = 1.0 / depth
        error = criterion(inv_depth, label)
        error = torch.sqrt(error / 2)
        writer1.add_scalar('loss', error, global_step=num)
        writer1.add_images('pre', prediction_d, global_step=num)
        writer1.add_images('label', label, global_step=num)
        # NOTE(review): inv_depth is (N, H, W) after squeeze(-3) for 1-channel
        # depth — confirm add_images' expected rank for this backend.
        writer1.add_images('process', inv_depth, global_step=num)
        print("ok")
tb_logger.add_image("GT heatmaps_{}".format(jj), hm[0].max(dim=0)[0].unsqueeze(0), global_step=global_step) # add predictions to TB tb_logger.add_image("attention maps", att[0], global_step=global_step) tb_preds = [] for jj, d in enumerate(det[0], 1): with torch.no_grad(): # d = d.clamp(0, 1) d = d.sigmoid() # d = d - d.min() # d /= d.max() tb_preds.append(d.unsqueeze(0)) tb_logger.add_images("detection maps", tb_preds, dataformats="CHW", global_step=global_step) # tb_logger.add_images("pred_stage_{}".format(jj), tb_preds, # global_step=global_step, # dataformats="CHW") # add weights and gradients histogram for name, param in student.named_parameters(): tb_logger.add_histogram(name + "_PARAMETERS", param.cpu().data.numpy(), global_step) if param.grad is not None: tb_logger.add_histogram(name + "_GRADIENTS", param.grad.cpu().data.numpy(), global_step) # # # if global_step % MINIVAL_EVERY_BATCHES == 0:
def main(args):
    """Extract CNN features for train/test images, cluster train features
    with mini-batch k-means, optionally visualize clusters and the test
    images' nearest clusters in tensorboard, and cache results to HDF5."""
    # Enable cuda by default
    args.cuda = True
    # Define transforms (ImageNet normalization).
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.ToTensor(), normalize
    ])
    # Create datasets (one per split; only train/test are actually used).
    datasets = {
        split: RGBDataset(
            os.path.join(args.dataset_root, split),
            seed=123,
            transform=transform,
            image_size=args.image_size,
            truncate_count=args.truncate_count,
        )
        for split in ["train", "val", "test"]
    }
    # Create data loaders
    data_loaders = {
        split: DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=16)
        for split, dataset in datasets.items()
    }
    device = torch.device("cuda:0" if args.cuda else "cpu")
    # Create model (feature extractor used frozen, in eval mode).
    net = FeatureNetwork()
    net.to(device)
    net.eval()
    # Generate image features for training images
    train_image_features = []
    train_image_paths = []
    for i, data in enumerate(data_loaders["train"], 0):
        # sample data
        inputs, input_paths = data
        inputs = {key: val.to(device) for key, val in inputs.items()}
        # Extract features
        with torch.no_grad():
            feats = net(inputs["rgb"])  # (bs, 512)
            feats = feats.detach().cpu().numpy()
        train_image_features.append(feats)
        train_image_paths += input_paths["rgb"]
    train_image_features = np.concatenate(train_image_features, axis=0)
    # Generate image features for testing images
    test_image_features = []
    test_image_paths = []
    for i, data in enumerate(data_loaders["test"], 0):
        # sample data
        inputs, input_paths = data
        inputs = {key: val.to(device) for key, val in inputs.items()}
        # Extract features
        with torch.no_grad():
            feats = net(inputs["rgb"])  # (bs, 512)
            feats = feats.detach().cpu().numpy()
        test_image_features.append(feats)
        test_image_paths += input_paths["rgb"]
    test_image_features = np.concatenate(test_image_features,
                                         axis=0)  # (N, 512)
    # ================= Perform clustering ==================
    kmeans = MiniBatchKMeans(
        init="k-means++",
        n_clusters=args.num_clusters,
        batch_size=args.batch_size,
        n_init=10,
        max_no_improvement=20,
        verbose=0,
    )
    save_h5_path = os.path.join(args.save_dir,
                                f"clusters_{args.num_clusters:05d}_data.h5")
    if os.path.isfile(save_h5_path):
        # Reuse cached centroids; only predict assignments.
        print("========> Loading existing clusters!")
        h5file = h5py.File(os.path.join(save_h5_path), "r")
        train_cluster_centroids = np.array(h5file["cluster_centroids"])
        kmeans.cluster_centers_ = train_cluster_centroids
        train_cluster_assignments = kmeans.predict(
            train_image_features)  # (N, )
        h5file.close()
    else:
        kmeans.fit(train_image_features)
        train_cluster_assignments = kmeans.predict(
            train_image_features)  # (N, )
        train_cluster_centroids = np.copy(
            kmeans.cluster_centers_)  # (num_clusters, 512)
    # Create a dictionary of cluster -> images for visualization
    cluster2image = {}
    if args.visualize_clusters:
        log_dir = os.path.join(
            args.save_dir,
            f"train_clusters_#clusters{args.num_clusters:05d}")
        tbwriter = SummaryWriter(log_dir=log_dir)
    for i in range(args.num_clusters):
        valid_idxes = np.where(train_cluster_assignments == i)[0]
        valid_image_paths = [train_image_paths[j] for j in valid_idxes]
        # Shuffle and pick only upto 100 images per cluster
        random.shuffle(valid_image_paths)
        # Read the valid images (BGR -> RGB via np.flip on the channel axis).
        # NOTE(review): np.stack raises on an empty cluster — confirm every
        # cluster is non-empty.
        valid_images = []
        for path in valid_image_paths[:100]:
            img = cv2.resize(
                np.flip(cv2.imread(path), axis=2),
                (args.image_size, args.image_size),
            )
            valid_images.append(img)
        valid_images = (np.stack(valid_images, axis=0).astype(np.float32) /
                        255.0)  # (K, H, W, C)
        valid_images = torch.Tensor(valid_images).permute(
            0, 3, 1, 2).contiguous()  # -> (K, C, H, W)
        cluster2image[i] = valid_images
        if args.visualize_clusters:
            # Write the train image clusters to tensorboard
            tbwriter.add_images(f"Cluster #{i:05d}", valid_images, 0)
    # Cache centroids and per-cluster sample images to HDF5.
    h5file = h5py.File(
        os.path.join(args.save_dir,
                     f"clusters_{args.num_clusters:05d}_data.h5"), "a")
    if "cluster_centroids" not in h5file.keys():
        h5file.create_dataset("cluster_centroids",
                              data=train_cluster_centroids)
    for i in range(args.num_clusters):
        if f"cluster_{i}/images" not in h5file.keys():
            h5file.create_dataset(f"cluster_{i}/images",
                                  data=cluster2image[i])
    h5file.close()
    if args.visualize_clusters:
        # Dot product of test_image_features with train_cluster_centroids
        test_dot_centroids = np.matmul(
            test_image_features,
            train_cluster_centroids.T)  # (N, num_clusters)
        if args.normalize_embedding:
            # Cosine-like scores in [-1, 1] -> [0, 1].
            test_dot_centroids = (test_dot_centroids + 1.0) / 2.0
        else:
            test_dot_centroids = F.softmax(
                torch.Tensor(test_dot_centroids), dim=1).numpy()
        # Find the top-K matching centroids
        topk_matches = np.argpartition(test_dot_centroids, -5,
                                       axis=1)[:, -5:]  # (N, 5)
        # Write the test nearest neighbors to tensorboard
        tbwriter = SummaryWriter(log_dir=os.path.join(
            args.save_dir,
            f"test_neighbors_#clusters{args.num_clusters:05d}"))
        # NOTE(review): assumes at least 100 test images.
        for i in range(100):
            test_image_path = test_image_paths[i]
            test_image = cv2.resize(cv2.imread(test_image_path),
                                    (args.image_size, args.image_size))
            test_image = np.flip(test_image,
                                 axis=2).astype(np.float32) / 255.0
            test_image = torch.Tensor(test_image).permute(2, 0,
                                                          1).contiguous()
            topk_clusters = topk_matches[i]
            # Pick some 4 images representative of a cluster
            topk_cluster_images = []
            for k in topk_clusters:
                imgs = cluster2image[k][:4]  # (4, C, H, W)
                if imgs.shape[0] == 0:
                    continue
                elif imgs.shape[0] != 4:
                    # Pad with black images to always have 4 tiles.
                    imgs_pad = torch.zeros(4 - imgs.shape[0],
                                           *imgs.shape[1:])
                    imgs = torch.cat([imgs, imgs_pad], dim=0)
                # Downsample by a factor of 2
                imgs = F.interpolate(imgs, scale_factor=0.5,
                                     mode="bilinear")  # (4, C, H/2, W/2)
                # Reshape to form a grid
                imgs = imgs.permute(1, 0, 2, 3)  # (C, 4, H/2, W/2)
                C, _, Hby2, Wby2 = imgs.shape
                imgs = (imgs.view(C, 2, 2, Hby2, Wby2).permute(
                    0, 1, 3, 2, 4).contiguous().view(C, Hby2 * 2, Wby2 * 2))
                # Draw a red border (4-pixel frame: R=1, G=B=0).
                imgs[0, :4, :] = 1.0
                imgs[1, :4, :] = 0.0
                imgs[2, :4, :] = 0.0
                imgs[0, -4:, :] = 1.0
                imgs[1, -4:, :] = 0.0
                imgs[2, -4:, :] = 0.0
                imgs[0, :, :4] = 1.0
                imgs[1, :, :4] = 0.0
                imgs[2, :, :4] = 0.0
                imgs[0, :, -4:] = 1.0
                imgs[1, :, -4:] = 0.0
                imgs[2, :, -4:] = 0.0
                topk_cluster_images.append(imgs)
            # Concatenate the test image and its top-k cluster grids
            # side-by-side (along width).
            vis_img = torch.cat([test_image, *topk_cluster_images], dim=2)
            image_name = f"Test image #{i:04d}"
            for k in topk_clusters:
                score = test_dot_centroids[i, k].item()
                image_name += f"_{score:.3f}"
            tbwriter.add_image(image_name, vis_img, 0)
class Trainer:
    """Training/evaluation harness for a VAE whose decoder outputs a
    per-pixel categorical distribution (argmax over dim 1 gives 0..255
    intensities, rescaled by /255 for visualization).

    Responsibilities: optimizer setup, training loop with per-epoch LR decay
    and tensorboard logging, periodic testing, checkpoint save/load.
    """

    def __init__(self, model, loss, train_loader, test_loader, args):
        self.model = model
        self.args = args
        self.args.start_epoch = 0
        self.train_loader = train_loader
        self.test_loader = test_loader
        # Loss function and Optimizer
        self.loss = loss
        self.optimizer = self.get_optimizer()
        # Tensorboard Writer
        self.summary_writer = SummaryWriter(log_dir=args.summary_dir)
        # Model Loading
        if args.resume:
            self.load_checkpoint(self.args.resume_from)

    def train(self):
        """Run the main loop: train each epoch, log mean loss and LR,
        checkpoint, and test every `args.test_every` epochs."""
        self.model.train()
        for epoch in range(self.args.start_epoch, self.args.num_epochs):
            loss_list = []
            print("epoch {}...".format(epoch))
            for batch_idx, (data, _) in enumerate(tqdm(self.train_loader)):
                if self.args.cuda:
                    data = data.cuda()
                data = Variable(data)
                self.optimizer.zero_grad()
                recon_batch, mu, logvar = self.model(data)
                loss = self.loss(recon_batch, data, mu, logvar)
                loss.backward()
                self.optimizer.step()
                loss_list.append(loss.item())
            print("epoch {}: - loss: {}".format(epoch, np.mean(loss_list)))
            new_lr = self.adjust_learning_rate(epoch)
            print('learning rate:', new_lr)
            self.summary_writer.add_scalar('training/loss',
                                           np.mean(loss_list), epoch)
            self.summary_writer.add_scalar('training/learning_rate', new_lr,
                                           epoch)
            self.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            })
            if epoch % self.args.test_every == 0:
                self.test(epoch)

    def test(self, cur_epoch):
        """Evaluate on the test set; log mean loss and one batch of
        input/reconstruction comparisons."""
        print('testing...')
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            for i, (data, _) in enumerate(self.test_loader):
                if self.args.cuda:
                    data = data.cuda()
                recon_batch, mu, logvar = self.model(data)
                test_loss += self.loss(recon_batch, data, mu, logvar).item()
                # argmax over the 256-way channel dim -> intensities in
                # [0, 1] after /255.
                _, indices = recon_batch.max(1)
                indices.data = indices.data.float() / 255
                if i == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat(
                        [data[:n], indices.view(-1, 3, 32, 32)[:n]])
                    self.summary_writer.add_images('testing_set/image',
                                                   comparison, cur_epoch)
                    comparison = torchvision.utils.make_grid(comparison,
                                                             nrow=6)
                    torchvision.utils.save_image(
                        comparison.cpu(),
                        'results/reconstruction_' + str(cur_epoch) + '.png',
                        nrow=8)
            test_loss /= len(self.test_loader.dataset)
            print('====> Test set loss: {:.4f}'.format(test_loss))
            self.summary_writer.add_scalar('testing/loss', test_loss,
                                           cur_epoch)
            self.model.train()

    def test_on_trainings_set(self):
        """Evaluate reconstruction loss on the TRAINING set (sanity check),
        logging a comparison image every 50 batches."""
        print('testing...')
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            for i, (data, _) in enumerate(self.train_loader):
                if self.args.cuda:
                    data = data.cuda()
                recon_batch, mu, logvar = self.model(data)
                test_loss += self.loss(recon_batch, data, mu, logvar).item()
                _, indices = recon_batch.max(1)
                indices.data = indices.data.float() / 255
                if i % 50 == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat(
                        [data[:n], indices.view(-1, 3, 32, 32)[:n]])
                    self.summary_writer.add_images('training_set/image',
                                                   comparison, i)
            # BUG FIX: average over the dataset actually iterated (the
            # training set); previously divided by the TEST set size.
            test_loss /= len(self.train_loader.dataset)
            print('====> Test on training set loss: {:.4f}'.format(test_loss))
            self.model.train()

    def get_optimizer(self):
        return optim.Adam(self.model.parameters(),
                          lr=self.args.learning_rate,
                          weight_decay=self.args.weight_decay)

    def adjust_learning_rate(self, epoch):
        """Sets the learning rate to the initial LR multiplied by
        learning_rate_decay**epoch; returns the new LR."""
        learning_rate = self.args.learning_rate * (
            self.args.learning_rate_decay**epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = learning_rate
        return learning_rate

    def save_checkpoint(self, state, is_best=False,
                        filename='checkpoint.pth.tar'):
        '''
        a function to save checkpoint of the training
        :param state: {'epoch': cur_epoch + 1, 'state_dict': self.model.state_dict(),
                       'optimizer': self.optimizer.state_dict()}
        :param is_best: boolean to save the checkpoint aside if it has the best score so far
        :param filename: the name of the saved file
        '''
        torch.save(state, self.args.checkpoint_dir + filename)
        if is_best:
            shutil.copyfile(self.args.checkpoint_dir + filename,
                            self.args.checkpoint_dir + 'model_best.pth.tar')

    def load_checkpoint(self, filename):
        """Load model/optimizer state from checkpoint_dir + filename and
        resume from its stored epoch; missing/corrupt files are skipped."""
        filename = self.args.checkpoint_dir + filename
        try:
            print("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)
            self.args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("Checkpoint loaded successfully from '{}' at (epoch {})\n".
                  format(self.args.checkpoint_dir, checkpoint['epoch']))
        except Exception:
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit; keep the best-effort skip but
            # only for ordinary exceptions.
            print("No checkpoint exists from '{}'. Skipping...\n".format(
                self.args.checkpoint_dir))
def train_pru_mod(epoch=100, batch=4, lr=0.001):
    """Fine-tune the pruned depth model on NYU data.

    Composite loss: 10*MSE + Laplacian smoothness + image-aware 2nd-order
    smoothness + (1 - SSIM). Logs losses per iteration and depth images per
    epoch to tensorboard; saves the model after every epoch.

    NOTE(review): `for epoch in range(epoch)` shadows the `epoch` argument —
    works, but confusing.
    """
    #net = load_t_net().double()
    net = load_pru_mod(after_finetune=True).double()
    net = nn.DataParallel(net)
    net = net.cuda()
    train_Data = nyu_set.use_nyu_data(batch_s=batch, max_len=160,
                                      isBenchmark=False)
    writer1 = SummaryWriter('./gj_dir/train_pru_mod')
    criterion = nn.MSELoss(reduction='mean').cuda()
    Joint = JointLoss(opt=None).double().cuda()
    s_loss = ts_loss.SSIM().cuda()
    optimizer = optim.Adam(net.parameters(), lr=lr)
    net.train()
    import time
    for epoch in range(epoch):
        time_start = time.time()
        batch_size = batch
        for i, data in enumerate(train_Data):
            images, depths = data
            # images = autograd.Variable(inputs.cuda(), requires_grad=False)
            images = Variable(images).double().cuda()
            depths = Variable(depths).double().cuda()
            # labels = labels.to(device).double()
            optimizer.zero_grad()
            # debug_img = transforms.ToPILImage()(images[0,:,:,:].float().cpu())
            # debug_img.save("debug.jpg")
            output_net = net(images)[0].double()  # 0 is the depth head
            # loss1 = 1 - s_loss.forward(output_s_features, T_mid_feature[0])
            # loss2 = criterion(output_s_depth,output_t)
            loss1 = criterion(output_net, depths)
            loss2 = Joint.LaplacianSmoothnessLoss(output_net, images)
            loss3 = Joint.compute_image_aware_2nd_smoothness_cost(
                output_net, images)
            #loss4 = Joint.compute_image_aware_1st_smoothness_cost(output_net,images)
            loss4 = 1 - s_loss.forward(output_net, depths)
            # MSE dominates (x10); the other terms act as regularizers.
            loss = loss1 * 10 + loss2 + loss3 + loss4
            loss.backward()
            optimizer.step()
            print('[%d, %5d] loss: %.4f A:%.4f B:%.4f C:%.4f D:%.4f' %
                  (epoch + 1, (i + 1) * batch_size, loss.item(),
                   loss1.item(), loss2.item(), loss3.item(), loss4.item()))
            writer1.add_scalar('loss', loss.item(),
                               global_step=(epoch + 1) * batch_size + i)
            writer1.add_scalar('loss2', loss2.item(),
                               global_step=(epoch + 1) * batch_size + i)
        #debug_img = transforms.ToPILImage()(output_net)
        # Per-epoch image logging uses the LAST batch of the epoch.
        writer1.add_images('pre', output_net, global_step=epoch)
        shape = images.shape
        # The network presumably predicts log-depth: exponentiate, then
        # invert and scale for display — TODO confirm the value range.
        dep = torch.exp(output_net)
        dep = dep.detach().cpu().numpy()
        inv_dep = 1.0 / dep * 255
        writer1.add_images('pro-dep', inv_dep, global_step=epoch)
        writer1.add_images('labels', depths, global_step=epoch)
        # Overwrites the same file every epoch (saves the whole module).
        torch.save(net.module, "./gj_dir/after_nyu.pth.tar")
        time_end = time.time()
        print('Time cost:', time_end - time_start, "s")
    print('Finished Training')
class PredictionCallback(tf.keras.callbacks.Callback):
    """Predictions logged using tensorboardX.

    Keras callback that, every `update_freq` epochs, runs the model on one
    validation batch and writes inputs, targets, and predictions (raw and
    color-overlaid) to tensorboard.
    """

    def __init__(self, model, logdir, val_generator, scaled_mask,
                 binary_threshold=0.5, update_freq=1):
        super(PredictionCallback, self).__init__()
        self.val_generator = val_generator
        self.writer = SummaryWriter(logdir=logdir)
        # scaled_mask: True => targets are single-channel float masks;
        # False => one-hot masks that get argmax-ed before logging.
        self.scaled_mask = scaled_mask
        self.binary_threshold = binary_threshold
        self._model = model
        # Number of output channels == number of classes.
        self.num_classes = self._model.output.shape.as_list()[-1]
        self.update_freq = update_freq  # log every N epochs

    def on_epoch_end(self, epoch, logs={}):
        # Shift the 0-based epoch to 1-based for logging purposes.
        if epoch == -1:
            epoch = 0
        else:
            epoch = epoch + 1
        if epoch % self.update_freq == 0:
            logger.debug("logging images to tensorboard, epoch=%d" % epoch)
            # Grab exactly one validation batch.
            for input_batch, target_batch in self.val_generator:
                input_batch = input_batch.numpy()
                target_batch = target_batch.numpy()
                break
            # input_batch, target_batch = next(iter(self.val_generator.as_numpy_iterator()))
            if self.scaled_mask:
                # Add a trailing channel axis so shapes are NHWC.
                target_batch = np.expand_dims(target_batch, axis=-1)
            # predict
            pred_batch = self._model.predict_on_batch(input_batch).numpy()
            # Predictions overlaid on the input images.
            predictions_on_inputs = masks.get_colored_segmentation_mask(
                pred_batch,
                self.num_classes,
                images=input_batch,
                binary_threshold=self.binary_threshold)
            self.writer.add_images('inputs/with_predictions',
                                   predictions_on_inputs,
                                   dataformats='NHWC',
                                   global_step=epoch)
            # Ground-truth targets overlaid on the input images.
            targets_on_inputs = masks.get_colored_segmentation_mask(
                target_batch,
                self.num_classes,
                images=input_batch,
                binary_threshold=self.binary_threshold)
            self.writer.add_images('inputs/with_targets',
                                   targets_on_inputs,
                                   dataformats='NHWC',
                                   global_step=epoch)
            # Stand-alone colored masks (alpha=1.0 => no blending).
            targets_rgb = masks.get_colored_segmentation_mask(
                target_batch,
                self.num_classes,
                binary_threshold=self.binary_threshold,
                alpha=1.0)
            self.writer.add_images('targets/rgb',
                                   targets_rgb,
                                   dataformats='NHWC',
                                   global_step=epoch)
            pred_rgb = masks.get_colored_segmentation_mask(
                pred_batch,
                self.num_classes,
                binary_threshold=self.binary_threshold,
                alpha=1.0)
            self.writer.add_images('predictions/rgb',
                                   pred_rgb,
                                   dataformats='NHWC',
                                   global_step=epoch)
            if not self.scaled_mask:
                # One-hot -> class-index maps for raw logging.
                pred_batch = np.argmax(pred_batch, axis=-1).astype(np.float32)
                target_batch = np.argmax(target_batch,
                                         axis=-1).astype(np.float32)
                # reshape, that add_images works
                pred_batch = np.expand_dims(pred_batch, axis=-1)
                target_batch = np.expand_dims(target_batch, axis=-1)
            else:
                # Binarize the scaled mask at the threshold.
                pred_batch[pred_batch > self.binary_threshold] = 1.0
                pred_batch[pred_batch <= self.binary_threshold] = 0.0
            self.writer.add_images('inputs',
                                   input_batch,
                                   dataformats='NHWC',
                                   global_step=epoch)
            self.writer.add_images('targets',
                                   target_batch,
                                   dataformats='NHWC',
                                   global_step=epoch)
            self.writer.add_images('predictions',
                                   pred_batch,
                                   dataformats='NHWC',
                                   global_step=epoch)
# Weighted sum of each loss into a total loss function reconstruction_loss = (1 - Lambda) * reconstruction_loss lstd_loss = Lambda * lstd_loss classifier_loss = (1 - Lambda) * classifier_loss total_loss = reconstruction_loss + lstd_loss + classifier_loss # Backprop Total loss (VAE + Classifer) and update VAE if training if phase == 'train': # Calculate Total loss and Backprop through network total_loss.backward() optim.step() if idx % 1000 == 0: writer.add_images("Images", inputs, global_step=epoch) writer.add_images("Reconstructions", recon_images, global_step=epoch) # writer.add_graph(VAE(), (inputs, labels), global_step=epoch) if idx % 100 == 0 and args.display: print( f'{phase} Batch Loss: {running_cls_loss}| Acc: {running_corrects / batch_size}' ) # (Classifer) Average Loss and Accuracy for current batch running_cls_loss += classifier_loss.item() * inputs.size(0) running_corrects += torch.sum(predicted == labels.data) # (VAE) Average Losses for current batch
def train(
    backbone,
    root_dir,
    train_index_fp,
    pretrain_model,
    optimizer,
    epochs=50,
    lr=0.001,
    wd=5e-4,
    momentum=0.9,
    batch_size=4,
    ctx=mx.cpu(),
    verbose_step=5,
    output_dir='ckpt',
):
    """Train a PSENet text-detection model with MXNet Gluon.

    Checkpoints are written to ``output_dir/<backbone>`` once per epoch and
    scalar/image summaries are logged with tensorboardX.

    :param backbone: base network name passed to ``PSENet``.
    :param root_dir: dataset root directory.
    :param train_index_fp: path to the training index file.
    :param pretrain_model: optional parameter file to warm-start from.
    :param optimizer: optimizer name ('adam' drops the momentum param).
    :param ctx: a single mx context or a list of them (multi-GPU);
        the effective batch size is multiplied by the number of contexts.
    :param verbose_step: log scalars every N batches.
    """
    output_dir = os.path.join(output_dir, backbone)
    os.makedirs(output_dir, exist_ok=True)
    num_kernels = 3
    dataset = StdDataset(root_dir=root_dir, train_idx_fp=train_index_fp,
                         num_kernels=num_kernels - 1)
    if not isinstance(ctx, (list, tuple)):
        ctx = [ctx]
    batch_size = batch_size * len(ctx)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net = PSENet(base_net_name=backbone, num_kernels=num_kernels, ctx=ctx,
                 pretrained=True)
    # initial params: Xavier everywhere, then force-reinit the extra/decoder
    # layers (weights Xavier, biases zero) that are not covered by pretraining
    net.initialize(mx.init.Xavier(), ctx=ctx)
    net.collect_params("extra_.*_weight|decoder_.*_weight").initialize(
        mx.init.Xavier(), ctx=ctx, force_reinit=True)
    net.collect_params("extra_.*_bias|decoder_.*_bias").initialize(
        mx.init.Zero(), ctx=ctx, force_reinit=True)
    if pretrain_model is not None:
        net.load_parameters(pretrain_model, ctx=ctx, allow_missing=True,
                            ignore_extra=True)
    # pse_loss = DiceLoss(lam=0.7, num_kernels=num_kernels)
    pse_loss = DiceLoss_with_OHEM(lam=0.7, num_kernels=num_kernels, debug=False)
    # lr_scheduler = ls.PolyScheduler(
    #     max_update=icdar_loader.length * epochs // batch_size, base_lr=lr
    # )
    # step-decay schedule: drop lr by 10x at 1/3 and 2/3 of total updates
    max_update = len(dataset) * epochs // batch_size
    lr_scheduler = ls.MultiFactorScheduler(
        base_lr=lr, step=[max_update // 3, max_update * 2 // 3], factor=0.1)
    optimizer_params = {
        'learning_rate': lr,
        'wd': wd,
        'momentum': momentum,
        'lr_scheduler': lr_scheduler,
    }
    if optimizer.lower() == 'adam':
        # Adam does not accept a momentum argument
        optimizer_params.pop('momentum')
    trainer = Trainer(net.collect_params(), optimizer=optimizer,
                      optimizer_params=optimizer_params)
    summary_writer = SummaryWriter(output_dir)
    for e in range(epochs):
        cumulative_loss = 0
        num_batches = 0
        for i, item in enumerate(loader):
            # split each field of the batch across the available contexts
            item_ctxs = [split_and_load(field, ctx) for field in item]
            loss_list = []
            for im, gt_text, gt_kernels, training_masks, ori_img in zip(
                    *item_ctxs):
                # network output is 1/4 resolution; subsample the targets
                gt_text = gt_text[:, ::4, ::4]
                gt_kernels = gt_kernels[:, :, ::4, ::4]
                training_masks = training_masks[:, ::4, ::4]
                with autograd.record():
                    kernels_pred = net(im)  # channel 0 predicts the complete text region
                    loss = pse_loss(gt_text, gt_kernels, kernels_pred,
                                    training_masks)
                    loss_list.append(loss)
            # backprop per-context losses, then a single trainer step that
            # normalizes by the (multi-context) batch size
            mean_loss = []
            for loss in loss_list:
                loss.backward()
                mean_loss.append(mx.nd.mean(to_cpu(loss)).asscalar())
            mean_loss = np.mean(mean_loss)
            trainer.step(batch_size)
            if i % verbose_step == 0:
                global_steps = dataset.length * e + i * batch_size
                summary_writer.add_scalar('loss', mean_loss, global_steps)
                summary_writer.add_scalar(
                    'c_loss',
                    mx.nd.mean(to_cpu(pse_loss.C_loss)).asscalar(),
                    global_steps,
                )
                summary_writer.add_scalar(
                    'kernel_loss',
                    mx.nd.mean(to_cpu(pse_loss.kernel_loss)).asscalar(),
                    global_steps,
                )
                summary_writer.add_scalar('pixel_accuracy', pse_loss.pixel_acc,
                                          global_steps)
            if i % 1 == 0:
                logger.info(
                    "step: {}, lr: {}, "
                    "loss: {}, score_loss: {}, kernel_loss: {}, pixel_acc: {}, kernel_acc: {}"
                    .format(
                        i * batch_size,
                        trainer.learning_rate,
                        mean_loss,
                        mx.nd.mean(to_cpu(pse_loss.C_loss)).asscalar(),
                        mx.nd.mean(to_cpu(pse_loss.kernel_loss)).asscalar(),
                        pse_loss.pixel_acc,
                        pse_loss.kernel_acc,
                    ))
            cumulative_loss += mean_loss
            num_batches += 1
        # NOTE(review): global_steps / gt_text / kernels_pred below reuse the
        # values from the *last* batch of the epoch — confirm intended.
        summary_writer.add_scalar('mean_loss_per_epoch',
                                  cumulative_loss / num_batches, global_steps)
        logger.info("Epoch {}, mean loss: {}\n".format(
            e, cumulative_loss / num_batches))
        net.save_parameters(
            os.path.join(output_dir, model_fn_prefix(backbone, e)))
        summary_writer.add_image('complete_gt', to_cpu(gt_text[0:1, :, :]),
                                 global_steps)
        summary_writer.add_image('complete_pred',
                                 to_cpu(kernels_pred[0:1, 0, :, :]),
                                 global_steps)
        # reshape(-1, 1, 0, 0): in MXNet a 0 keeps the corresponding input dim
        summary_writer.add_images(
            'kernels_gt',
            to_cpu(gt_kernels[0:1, :, :, :]).reshape(-1, 1, 0, 0),
            global_steps,
        )
        summary_writer.add_images(
            'kernels_pred',
            to_cpu(kernels_pred[0:1, 1:, :, :]).reshape(-1, 1, 0, 0),
            global_steps,
        )
    summary_writer.close()
class Train:
    """Training driver for (DR-)TANet change detection.

    Expects the surrounding setup code to populate ``self.args``,
    ``self.dataset_train_loader``, ``self.dataset_test`` and
    ``self.checkpoint_save`` before ``run()`` is called.

    Fix vs. previous revision: the per-epoch loss accumulator stored the loss
    *tensor* (``loss_item.append(self.loss)``), which kept every step's
    autograd graph alive and leaked GPU memory; it now stores the detached
    scalar via ``.item()``.
    """

    def __init__(self):
        self.epoch = 0  # completed epochs
        self.step = 0   # total optimizer steps taken

    def train(self):
        """Run the full optimization loop and write losses/images to disk."""
        weight = torch.ones(2)
        criterion = criterion_CEloss(weight.cuda())
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001,
                                     betas=(0.9, 0.999))
        # linear decay to zero over the total number of steps
        lambda_lr = lambda epoch: (float)(self.args.max_epochs * len(self.dataset_train_loader) - self.step) / (float)(self.args.max_epochs * len(self.dataset_train_loader))
        model_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda_lr)
        f_loss = open(pjoin(self.checkpoint_save, "loss.csv"), 'w')
        loss_writer = csv.writer(f_loss)
        self.visual_writer = SummaryWriter(
            os.path.join(self.checkpoint_save, 'logs'))
        loss_item = []
        max_step = self.args.max_epochs * len(self.dataset_train_loader)
        # one 2x2 tile (t0 | t1 / gt | pred) per periodic test
        _, w, h = self.dataset_test.get_random_image()[0].shape
        img_tbx = np.zeros((max_step // self.args.step_test, 3, w * 2, h * 2),
                           dtype=np.uint8)
        while self.epoch < self.args.max_epochs:
            for step, (inputs_train, mask_train) in enumerate(
                    tqdm(self.dataset_train_loader)):
                self.model.train()
                inputs_train = inputs_train.cuda()
                mask_train = mask_train.cuda()
                output_train = self.model(inputs_train)
                optimizer.zero_grad()
                self.loss = criterion(output_train, mask_train[:, 0])
                # append the scalar, not the tensor — keeping the tensor would
                # retain the whole computation graph for the epoch
                loss_item.append(self.loss.item())
                self.loss.backward()
                optimizer.step()
                self.step += 1
                loss_writer.writerow([self.step, self.loss.item()])
                self.visual_writer.add_scalar('loss', self.loss.item(),
                                              self.step)
                if self.args.step_test > 0 and self.step % self.args.step_test == 0:
                    print('testing...')
                    self.model.eval()
                    self.test(img_tbx)
            print('Loss for Epoch {}:{:.03f}'.format(
                self.epoch, sum(loss_item) / len(self.dataset_train_loader)))
            loss_item.clear()
            model_lr_scheduler.step()
            self.epoch += 1
            if self.args.epoch_save > 0 and self.epoch % self.args.epoch_save == 0:
                self.checkpoint()
        # all collected test tiles as one image grid
        self.visual_writer.add_images('cd_test', img_tbx, 0,
                                      dataformats='NCHW')
        f_loss.close()
        self.visual_writer.close()

    def test(self, img_tbx):
        """Run one random test image through the model and store a result tile."""
        _, _, w_r, h_r = img_tbx.shape
        w_r //= 2
        h_r //= 2
        input, mask_gt = self.dataset_test.get_random_image()
        input = input.view(1, -1, h_r, w_r)
        input = input.cuda()
        output = self.model(input)
        input = input[0].cpu().data
        # channels 0:3 / 3:6 are the two time points, normalized to [-1, 1]
        img_t0 = input[0:3, :, :]
        img_t1 = input[3:6, :, :]
        img_t0 = (img_t0 + 1) * 128
        img_t1 = (img_t1 + 1) * 128
        output = output[0].cpu().data
        # class-0 probability > 0.5 -> no change (0); otherwise change (255)
        mask_pred = np.where(
            F.softmax(output[0:2, :, :], dim=0)[0] > 0.5, 0, 255)
        mask_gt = np.squeeze(np.where(mask_gt == True, 255, 0), axis=0)
        self.store_result(img_t0, img_t1, mask_gt, mask_pred, img_tbx)

    def store_result(self, t0, t1, mask_gt, mask_pred, img_save):
        """Paste t0/t1 images and gt/pred masks into one 2x2 tile of img_save."""
        _, _, w, h = img_save.shape
        w //= 2
        h //= 2
        i = self.step // self.args.step_test - 1
        img_save[i, :, 0:w, 0:h] = t0.numpy().astype(np.uint8)
        img_save[i, :, 0:w, h:2 * h] = t1.numpy().astype(np.uint8)
        img_save[i, :, w:2 * w, 0:h] = np.transpose(
            cv2.cvtColor(mask_gt.astype(np.uint8), cv2.COLOR_GRAY2RGB),
            (2, 0, 1)).astype(np.uint8)
        img_save[i, :, w:2 * w, h:2 * h] = np.transpose(
            cv2.cvtColor(mask_pred.astype(np.uint8), cv2.COLOR_GRAY2RGB),
            (2, 0, 1)).astype(np.uint8)
        #img_save = np.transpose(img_save, (1, 0, 2))

    def checkpoint(self):
        """Save the current model weights under checkpointdir/<step>.pth."""
        filename = '{:08d}.pth'.format(self.step)
        cp_path = pjoin(self.checkpoint_save, 'checkpointdir')
        if not os.path.exists(cp_path):
            os.makedirs(cp_path)
        torch.save(self.model.state_dict(), pjoin(cp_path, filename))
        print("Net Parameters in step:{:08d} were saved.".format(self.step))

    def run(self):
        """Build the TANet model from args and start training."""
        self.model = TANet(self.args.encoder_arch,
                           self.args.local_kernel_size,
                           self.args.attn_stride,
                           self.args.attn_padding,
                           self.args.attn_groups,
                           self.args.drtam,
                           self.args.refinement)
        if self.args.drtam:
            print('Dynamic Receptive Temporal Attention Network (DR-TANet)')
        else:
            print('Temporal Attention Network (TANet)')
        print('Encoder:' + self.args.encoder_arch)
        if self.args.refinement:
            print('Adding refinement...')
        if self.args.multi_gpu:
            self.model = nn.DataParallel(self.model).cuda()
        else:
            self.model = self.model.cuda()
        self.train()
def fit(self, train_loader_fn: Callable, epochs: int = 2, lr: float = 1e-3,
        n_critic: int = 5, disk_backup_filename: str = "dumped_weights.bin"):
    """
    Trains this WGAN.

    :param train_loader_fn: loader-returning function to generate training data.
    :param epochs: number of epochs. An epoch is a full pass over the train loader.
    :param lr: learning rate.
    :param n_critic: number of steps to train the discriminator (a.k.a. critic)
        per each training step of the generator. In WGANs, it is OK to make it large.
    :param disk_backup_filename: filename to dump trainable parameters.
        Dumping is done once per epoch.

    Fix vs. previous revision: the critic's gradients were zeroed only once
    before the ``n_critic`` inner loop while ``d_optimizer.step()`` ran inside
    it, so gradients from earlier critic steps accumulated into later ones.
    ``d_optimizer.zero_grad()`` is now called at the start of every critic step.
    """
    Util.set_param_requires_grad(self.generator, True)
    Util.set_param_requires_grad(self.discriminator, True)
    if not self.params:
        self.random_init()
        self.save_params()
        self.save_params_to_disk(disk_backup_filename)
    else:
        self.restore_params()
    # RMSprop as in the original WGAN paper (momentum-free)
    g_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=lr)
    d_optimizer = torch.optim.RMSprop(self.discriminator.parameters(), lr=lr)
    self.generator.train()
    self.discriminator.train()
    writer = SummaryWriter("gan_training")
    batch_index = 0
    for epoch in range(epochs):
        data_sampler = iter(train_loader_fn())
        while True:
            # preload real batches
            real_batches = []
            for i in range(n_critic):
                try:
                    real_data, _ = next(data_sampler)
                    real_batches.append(real_data)
                except StopIteration:
                    # for simplicity, omitting the last incomplete sequence of batches
                    break
            if len(real_batches) != n_critic:
                # next epoch
                break
            batch_size = real_batches[0].shape[0]
            # train the critic n_critic times per generator step
            for i in range(n_critic):
                # zero per critic step so gradients do not accumulate
                d_optimizer.zero_grad()
                real_data = real_batches[i]
                fake_data = self.generator(
                    Util.conditional_to_cuda(
                        torch.randn(batch_size, self.latent_dim)))
                loss1 = self.discriminator(real_data).mean()
                loss2 = self.discriminator(fake_data).mean()
                # maximize E[D(real)] - E[D(fake)] <=> minimize its negation
                discriminator_loss = -(loss1 - loss2)
                discriminator_loss.backward()
                d_optimizer.step()
            # one generator step: minimize -E[D(fake)]
            g_optimizer.zero_grad()
            fake_data = self.generator(
                Util.conditional_to_cuda(
                    torch.randn(batch_size, self.latent_dim)))
            generator_loss = -self.discriminator(fake_data).mean()
            #generator_loss = (fake_data - real_data).abs().mean()
            generator_loss.backward()
            g_optimizer.step()
            # eval
            with torch.no_grad():
                writer.add_scalar("discriminator_loss",
                                  discriminator_loss.detach(), batch_index)
                writer.add_scalar("generator_loss",
                                  generator_loss.detach(), batch_index)
                writer.add_scalar("epoch", epoch, batch_index)
                if batch_index % 20 == 0:
                    writer.add_images("generated_batch",
                                      fake_data.detach().clamp(-1, 1),
                                      batch_index)
                    writer.add_images("real_batch", real_data.detach(),
                                      batch_index)
            batch_index += 1
def fit(model, device, data_path, epochs=5, batch_size=2, lr=0.001,
        val_percent=0.1):
    """Train a segmentation model on the IrisDataset with periodic validation.

    Splits the dataset into train/validation, trains with RMSprop, reduces the
    learning rate on validation-loss plateaus and logs losses, parameter/grad
    histograms and sample images to TensorBoard.

    Fixes vs. previous revision:
      * the 'grads/...' histograms logged ``value.data`` (the weights) instead
        of the gradients; they now log ``value.grad`` (skipped when None);
      * ``len(dataset) // (10 * batch_size)`` could be 0 for small datasets,
        crashing the modulo — the interval is now clamped to >= 1;
      * ``model.train()`` is restored after the in-loop validation pass so the
        rest of the epoch does not silently run in eval mode.

    :param model: torch module with an ``n_classes`` attribute.
    :param device: torch device to train on.
    :param data_path: argument tuple forwarded to ``IrisDataset``.
    :param val_percent: fraction of the dataset held out for validation.
    """
    # get image/mask data
    dataset = IrisDataset(*data_path)
    n_valid = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_valid
    train_ds, valid_ds = random_split(dataset, [n_train, n_valid])
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=True)
    valid_dl = DataLoader(valid_ds, batch_size=batch_size, shuffle=False,
                          num_workers=4, pin_memory=True, drop_last=True)
    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_valid}
        Device:          {device.type}
    ''')
    writer = SummaryWriter()
    global_step = 0
    optimizer = optim.RMSprop(model.parameters(), lr=lr, weight_decay=1e-8,
                              momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    # multi-class -> CE on logits; binary -> BCE with logits
    loss_func = nn.CrossEntropyLoss(
    ) if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    # validate/log roughly 10 times per epoch; never 0 (would crash the modulo)
    val_interval = max(1, len(dataset) // (10 * batch_size))
    for epoch in range(epochs):
        with tqdm(total=n_train, desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='image') as pbar:
            model.train()
            for image, mask in train_dl:
                loss, _ = loss_batch(model, device, loss_func, image, mask,
                                     optimizer)
                writer.add_scalar('Loss/train', loss, global_step)
                pbar.set_postfix(**{'loss (batch)': loss})
                pbar.update(image.shape[0])
                global_step += 1
                if global_step % val_interval == 0:
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        # log the actual gradients (not the weights again)
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag,
                                                 value.grad.data.cpu().numpy(),
                                                 global_step)
                    model.eval()
                    with torch.no_grad():
                        losses, nums = zip(*[
                            loss_batch(model, device, loss_func, image, mask)
                            for image, mask in valid_dl
                        ])
                    # weighted mean over validation batches
                    val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
                    scheduler.step(val_loss)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)
                    if model.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_loss))
                        writer.add_scalar('Loss/test', val_loss, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_loss))
                        writer.add_scalar('Dice/test', val_loss, global_step)
                    # `image`/`mask` here are still the current training batch
                    # (the comprehension above has its own scope)
                    writer.add_images('images', image, global_step)
                    if model.n_classes == 1:
                        writer.add_images('masks/true', mask, global_step)
                        image = image.to(device=device, dtype=torch.float32)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(model(image)) > 0.5,
                                          global_step)
                    # resume training mode for the rest of the epoch
                    model.train()
    writer.close()
def train(opt):
    """Train the multi-head EfficientDet model (boxes, landmarks, attributes).

    Creates a dated log directory, trains for ``opt.epochs`` with OneCycleLR,
    logs per-head losses and sample predictions to tensorboardX, validates
    every ``opt.valid_step`` epochs and checkpoints the best/periodic weights.

    NOTE(review): `gt_age`, `gt_race`, `gt_skin` are loaded but never used in
    the loss dict; the best-model test compares the *last batch* loss
    (`verb_loss`) rather than the validation mean — confirm both are intended.
    """
    # one log dir per calendar day; suffix with a random int on collision
    date = datetime.date(datetime.now())
    logs = '../logs/'
    logdir = os.path.join(logs, str(date))
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    else:
        logdir = logdir + "_" + str(np.random.randint(0, 1000))
        os.mkdir(logdir)
    train_data = AllInOneData(opt.train_path, set='train',
                              transforms=transforms.Compose([Normalizer(), Resizer()]))
    train_generator = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size,
                                                  shuffle=True, num_workers=8,
                                                  collate_fn=collater, drop_last=True)
    valid_data = AllInOneData(opt.train_path, set='validation',
                              transforms=transforms.Compose([Normalizer(), Resizer()]))
    valid_generator = torch.utils.data.DataLoader(valid_data, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=8,
                                                  collate_fn=collater, drop_last=True)
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = EfficientDetMultiBackbone(opt.train_path, compound_coef=0, heads=opt.heads)
    model.to(device)
    min_val_loss = 10e5
    if opt.optim == 'Adam':
        optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr,
                                    momentum=opt.momentum, nesterov=True)
    # one-cycle schedule stepped per batch (see scheduler.step() below)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, opt.lr, total_steps=None,
                                                    epochs=opt.epochs,
                                                    steps_per_epoch=len(train_generator),
                                                    pct_start=0.1, anneal_strategy='cos',
                                                    cycle_momentum=True, base_momentum=0.85,
                                                    max_momentum=0.95, div_factor=25.0,
                                                    final_div_factor=1000.0, last_epoch=-1)
    criterion = MTLoss(heads=opt.heads, device=device)
    print('Model is successfully initiated')
    print(f'Targets are {opt.heads}.')
    verb_loss = 0
    writer = SummaryWriter(logdir=logdir,
                           filename_suffix=f'Train_{"_".join(opt.heads)}',
                           comment='try1')
    for epoch in range(opt.epochs):
        model.train()
        Losses = {k: [] for k in opt.heads}
        description = f'Epoch:{epoch}| Total Loss:{verb_loss}'
        progress_bar = tqdm(train_generator, desc=description)
        Total_loss = []
        for sample in progress_bar:
            imgs = sample['img'].to(device)
            gt_person_bbox = sample['person_bbox'].to(device)
            gt_face_bbox = sample['face_bbox'].to(device)
            gt_pose = sample['pose'].to(device)
            gt_face_landmarks = sample['face_landmarks'].to(device)
            gt_age = sample['age'].to(device)
            gt_race = sample['race'].to(device)
            gt_gender = sample['gender'].to(device)
            gt_skin = sample['skin'].to(device)
            gt_emotions = sample['emotion'].to(device)
            out = model(imgs)
            annot = {'person': gt_person_bbox, 'gender': gt_gender,
                     'face': gt_face_bbox, 'emotions': gt_emotions,
                     'face_landmarks': gt_face_landmarks,
                     'pose': gt_pose}
            losses, lm_mask = criterion(out, annot, out['anchors'])
            loss = torch.zeros(1).to(device)
            # total loss = sum of all per-head losses
            loss = torch.sum(torch.cat(list(losses.values())))
            loss.backward()
            optimizer.step()
            scheduler.step()
            verb_loss = loss.detach().cpu().numpy()
            Total_loss.append(verb_loss)
            description = f'Epoch:{epoch}| Total Loss:{verb_loss}|'
            for k, v in losses.items():
                Losses[k].append(v.detach().cpu().numpy())
                description += f'{k}:{round(np.mean(Losses[k]), 1)}|'
            progress_bar.set_description(description)
            # NOTE(review): zero_grad runs after step() at the end of each
            # iteration (first backward starts from fresh grads) — unusual
            # placement but equivalent here.
            optimizer.zero_grad()
        writer.add_scalar('Train/Total', round(np.mean(Total_loss), 2), epoch)
        for k in Losses.keys():
            writer.add_scalar(f"Train/{k}", round(np.mean(Losses[k]), 2), epoch)
        if epoch % opt.valid_step == 0:
            # visualize last training batch: boxes and landmark heatmaps
            im = (imgs[0] + 1) / 2 * 255
            regressBoxes = BBoxTransform()
            clipBoxes = ClipBoxes()
            pp = postprocess(imgs, out['anchors'], out['person'], out['gender'],
                             regressBoxes, clipBoxes, 0.4, 0.4)
            writer.add_image_with_boxes('Train/Box_prediction', im,
                                        pp[0]['rois'], epoch)
            img2 = out['face_landmarks']
            if img2.shape[1] > 3:
                # collapse per-landmark channels into one heatmap
                img2 = img2.sum(axis=1).unsqueeze(1) * 255
                lm_mask = lm_mask.sum(axis=1).unsqueeze(1) * 255
            writer.add_images('Train/landmarks_prediction', img2, epoch)
            writer.add_images('Train/landmark target', lm_mask, epoch)
            #VALIDATION STEPS
            model.eval()
            with torch.no_grad():
                valid_Losses = {k: [] for k in opt.heads}
                val_description = f'Validation| Total Loss:{verb_loss}'
                progress_bar = tqdm(valid_generator, desc=val_description)
                Total_loss = []
                for sample in progress_bar:
                    imgs = sample['img'].to(device)
                    gt_person_bbox = sample['person_bbox'].to(device)
                    gt_face_bbox = sample['face_bbox'].to(device)
                    gt_pose = sample['pose'].to(device)
                    gt_face_landmarks = sample['face_landmarks'].to(device)
                    gt_age = sample['age'].to(device)
                    gt_race = sample['race'].to(device)
                    gt_gender = sample['gender'].to(device)
                    gt_skin = sample['skin'].to(device)
                    gt_emotions = sample['emotion'].to(device)
                    out = model(imgs)
                    annot = {'person': gt_person_bbox, 'gender': gt_gender,
                             'face': gt_face_bbox, 'emotions': gt_emotions,
                             'face_landmarks': gt_face_landmarks,
                             'pose': gt_pose}
                    losses, lm_mask = criterion(out, annot, out['anchors'])
                    loss = torch.zeros(1).to(device)
                    loss = torch.sum(torch.cat(list(losses.values())))
                    verb_loss = loss.detach().cpu().numpy()
                    Total_loss.append(verb_loss)
                    val_description = f'Validation| Total Loss:{verb_loss}|'
                    for k, v in losses.items():
                        valid_Losses[k].append(v.detach().cpu().numpy())
                        val_description += f'{k}:{round(np.mean(valid_Losses[k]), 1)}|'
                    progress_bar.set_description(val_description)
                writer.add_scalar('Validation/Total',
                                  round(np.mean(Total_loss), 2), epoch)
                for k in valid_Losses.keys():
                    writer.add_scalar(f"Validation/{k}",
                                      round(np.mean(valid_Losses[k]), 2), epoch)
                im = (imgs[0] + 1) / 2 * 255
                regressBoxes = BBoxTransform()
                clipBoxes = ClipBoxes()
                pp = postprocess(imgs, out['anchors'], out['person'],
                                 out['gender'], regressBoxes, clipBoxes,
                                 0.4, 0.4)
                writer.add_image_with_boxes('Validation/Box_prediction', im,
                                            pp[0]['rois'], epoch)
                img2 = out['face_landmarks']
                if img2.shape[1] > 3:
                    img2 = img2.sum(axis=1).unsqueeze(1) * 255
                    lm_mask = lm_mask.sum(axis=1).unsqueeze(1) * 255
                writer.add_images('Validation/landmarks_prediction', img2, epoch)
                writer.add_images('Validation/landmark target', lm_mask, epoch)
                if verb_loss < min_val_loss:
                    print("The model improved and checkpoint is saved.")
                    torch.save(model.state_dict(),
                               f'{logdir}/{opt.save_name.split(".pt")[0]}_best_epoch_{epoch}.pt')
                    min_val_loss = verb_loss
        if epoch % 100 == 0:
            torch.save(model.state_dict(),
                       f'{logdir}/{opt.save_name.split(".pt")[0]}_epoch_{epoch}.pt')
    torch.save(model.state_dict(),
               f'{logdir}/{opt.save_name.split(".pt")[0]}_last.pt')
    writer.close()
def train(args):
    """Train a surface-normal estimation network on RGB / depth / RGB-D input.

    Loads train/test splits, optionally warm-starts from a pretrained model,
    trains with RMSprop + MultiStepLR and *manually supplied* gradients
    (``outputs.backward(gradient=df)`` with ``df`` produced by ``get_lossfun``),
    evaluates each epoch, and checkpoints whenever the evaluation mean loss
    improves. Scalars/images are logged via tensorboardX.
    """
    writer = SummaryWriter(comment=args.writer)
    # data loader setting, train and evaluation
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, split='train',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)
    v_loader = data_loader(data_path, split='test',
                           img_size=(args.img_rows, args.img_cols),
                           img_norm=args.img_norm)
    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size,
                                  shuffle=True, num_workers=args.num_workers)
    evalloader = data.DataLoader(v_loader, batch_size=args.batch_size,
                                 num_workers=args.num_workers)
    print("Finish Loader Setup")

    # Setup Model and load pretrained model
    model_name = args.arch_RGB
    # print(model_name)
    model = get_model(model_name, True)  # vgg_16
    if args.pretrain:  # True by default
        if args.input == 'rgb':  # only for rgb we have pretrain option
            state = get_premodel(model, args.state_name)
            model.load_state_dict(state)
            model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        elif args.input == 'd':  # for d, load from result from...
            print("Load training model: {}_{}_{}_{}_best.pkl".format(args.arch_RGB, args.dataset, args.loss, 1))
            checkpoint = torch.load(pjoin(args.model_savepath_pretrain,
                                          "{}_{}_{}_{}_best.pkl".format(args.arch_RGB, args.dataset, args.loss, 1)))
            # model.load_state_dict(load_resume_state_dict(model, checkpoint['model_D_state']))
            # wrap in DataParallel *before* loading: checkpoint keys carry the
            # 'module.' prefix
            model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
            model.load_state_dict(checkpoint['model_D_state'])
        else:
            model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
    # model_RGB = DataParallelWithCallback(model_RGB, device_ids=range(torch.cuda.device_count()))
    model.cuda()
    print("Finish model setup with model %s and state_dict %s" % (args.arch_RGB, args.state_name))

    # optimizers and lr-decay setting (finetuning uses a 4x smaller base lr
    # and an earlier decay schedule)
    if args.pretrain:  # True by default
        optimizer_RGB = torch.optim.RMSprop(model.parameters(), lr=0.25 * args.l_rate)
        scheduler_RGB = torch.optim.lr_scheduler.MultiStepLR(optimizer_RGB, milestones=[1, 2, 4, 8], gamma=0.5)
    else:
        optimizer_RGB = torch.optim.RMSprop(model.parameters(), lr=args.l_rate)
        scheduler_RGB = torch.optim.lr_scheduler.MultiStepLR(optimizer_RGB, milestones=[1, 3, 5, 8, 11, 15], gamma=0.5)

    # forward and backward
    best_loss = 3
    n_iter_t, n_iter_v = 0, 0
    # hard-coded dataset sizes, only used for progress printing
    if args.dataset == 'matterport':
        total_iter_t = 105432 / args.batch_size
    elif args.dataset == 'scannet':
        total_iter_t = 59743 / args.batch_size
    else:
        total_iter_t = 0
    if not os.path.exists(args.model_savepath):
        os.makedirs(args.model_savepath)
    for epoch in range(args.n_epoch):
        # NOTE(review): scheduler stepped before the optimizer within the
        # epoch — the pre-1.1 PyTorch convention; confirm the target version.
        scheduler_RGB.step()
        model.train()
        for i, (images, labels, masks, valids, depthes, meshdepthes) in enumerate(trainloader):
            n_iter_t += 1
            images = Variable(images.contiguous().cuda())
            labels = Variable(labels.contiguous().cuda())
            masks = Variable(masks.contiguous().cuda())
            optimizer_RGB.zero_grad()
            # pick the model input according to the configured modality
            if args.input == 'rgb':
                outputs = model(images)
            else:
                depthes = Variable(depthes.contiguous().cuda())
                if args.input == 'rgbd':
                    rgbd_input = torch.cat((images, depthes), dim=1)
                    outputs = model(rgbd_input)
                elif args.input == 'd':
                    outputs = model(depthes)
            # get_lossfun returns (scalar loss, gradient tensor df) — backward
            # is driven by the externally supplied gradients, not by loss
            loss, df = get_lossfun(args.loss, outputs, labels, masks)
            if args.l1regular:
                loss_rgl, df_rgl = get_lossfun('l1gra', outputs, labels, masks)
            elif args.gradloss:
                loss_grad, df_grad = get_lossfun('gradmap', outputs, labels, masks)
            if args.l1regular:
                outputs.backward(gradient=df, retain_graph=True)
                outputs.backward(gradient=0.1 * df_rgl)
            elif args.gradloss:
                outputs.backward(gradient=df, retain_graph=True)
                outputs.backward(gradient=0.5 * df_grad)
            else:
                outputs.backward(gradient=df)
            optimizer_RGB.step()
            if (i + 1) % 100 == 0:
                if args.l1regular:
                    print("Epoch [%d/%d] Iter [%d/%d] Loss and RGL: %.4f, %.4f" % (
                        epoch + 1, args.n_epoch, i, total_iter_t, loss.data, loss_rgl.data))
                elif args.gradloss:
                    print("Epoch [%d/%d] Iter [%d/%d] Loss and GradLoss: %.4f, %.4f" % (
                        epoch + 1, args.n_epoch, i, total_iter_t, loss.data, loss_grad.data))
                else:
                    print("Epoch [%d/%d] Iter [%d/%d] Loss: %.4f" % (
                        epoch + 1, args.n_epoch, i, total_iter_t, loss.data))
            if (i + 1) % 250 == 0:
                writer.add_scalar('loss/trainloss', loss.data.item(), n_iter_t)
                if args.l1regular:
                    writer.add_scalar('loss/trainloss_rgl', loss_rgl.data.item(), n_iter_t)
                elif args.gradloss:
                    writer.add_scalar('loss/trainloss_grad', loss_grad.data.item(), n_iter_t)
                writer.add_images('Image', images + 0.5, n_iter_t)
                if args.input != 'rgb':
                    # min-max normalize depth and repeat to 3 channels for display
                    writer.add_images('Depth', np.repeat(
                        (depthes - torch.min(depthes)) / (torch.max(depthes) - torch.min(depthes)),
                        3, axis=1), n_iter_t)
                # labels are NHWC normals in [-1, 1]; map to NCHW in [0, 1]
                writer.add_images('Label', 0.5 * (labels.permute(0, 3, 1, 2) + 1), n_iter_t)
                outputs_n = norm_tf(outputs)
                writer.add_images('Output', outputs_n, n_iter_t)
        model.eval()
        mean_loss, sum_loss, sum_rgl, sum_grad = 0, 0, 0, 0
        evalcount = 0
        with torch.no_grad():
            for i_val, (images_val, labels_val, masks_val, valids_val,
                        depthes_val, meshdepthes_val) in tqdm(enumerate(evalloader)):
                n_iter_v += 1
                images_val = Variable(images_val.contiguous().cuda())
                labels_val = Variable(labels_val.contiguous().cuda())
                masks_val = Variable(masks_val.contiguous().cuda())
                if args.input == 'rgb':
                    outputs = model(images_val)
                else:
                    depthes_val = Variable(depthes_val.contiguous().cuda())
                    if args.input == 'rgbd':
                        rgbd_input = torch.cat((images_val, depthes_val), dim=1)
                        outputs = model(rgbd_input)
                    elif args.input == 'd':
                        outputs = model(depthes_val)
                loss, df = get_lossfun(args.loss, outputs, labels_val, masks_val, False)  # valid_val not used infact
                if args.l1regular:
                    loss_rgl, df_rgl = get_lossfun('l1gra', outputs, labels_val, masks_val, False)
                elif args.gradloss:
                    loss_grad, df_grad = get_lossfun('gradmap', outputs, labels_val, masks_val, False)
                # skip NaN/Inf batches so they do not poison the epoch mean
                if ((np.isnan(loss)) | (np.isinf(loss))):
                    sum_loss += 0
                else:
                    sum_loss += loss
                    evalcount += 1
                if args.l1regular:
                    sum_rgl += loss_rgl
                elif args.gradloss:
                    sum_grad += loss_grad
                if (i_val + 1) % 250 == 0:
                    # print("Epoch [%d/%d] Evaluation Loss: %.4f" % (epoch+1, args.n_epoch, loss))
                    writer.add_scalar('loss/evalloss', loss, n_iter_v)
                    writer.add_images('Eval Image', images_val + 0.5, n_iter_t)
                    if args.input != 'rgb':
                        writer.add_image('Depth', np.repeat(
                            (depthes_val - torch.min(depthes_val)) / (torch.max(depthes_val) - torch.min(depthes_val)),
                            3, axis=1), n_iter_t)
                    writer.add_images('Eval Label', 0.5 * (labels_val.permute(0, 3, 1, 2) + 1), n_iter_t)
                    outputs_n = norm_tf(outputs)
                    writer.add_images('Eval Output', outputs_n, n_iter_t)
            mean_loss = sum_loss / evalcount
            print("Epoch [%d/%d] Evaluation Mean Loss: %.4f" % (epoch + 1, args.n_epoch, mean_loss))
            writer.add_scalar('loss/evalloss_mean', mean_loss, epoch)
            writer.add_scalar('loss/evalloss_rgl_mean', sum_rgl / evalcount, epoch)
            writer.add_scalar('loss/evalloss_grad_mean', sum_grad / evalcount, epoch)
        # checkpoint on improvement of the evaluation mean loss
        if mean_loss < best_loss:
            # if (epoch+1)%20 == 0:
            best_loss = mean_loss
            state = {'epoch': epoch + 1,
                     'model_RGB_state': model.state_dict(),
                     'optimizer_RGB_state': optimizer_RGB.state_dict(), }
            if args.pretrain:
                if args.l1regular:
                    torch.save(state, pjoin(args.model_savepath,
                                            "{}_{}_{}_{}_rgls_best.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
                elif args.gradloss:
                    torch.save(state, pjoin(args.model_savepath,
                                            "{}_{}_{}_{}_grad_best.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
                else:
                    torch.save(state, pjoin(args.model_savepath,
                                            "{}_{}_{}_{}_resume_RGB_best.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
            else:
                torch.save(state, pjoin(args.model_savepath,
                                        "{}_{}_{}_{}_resume_RGB_best.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
    print('Finish training for dataset %s trial %s' % (args.dataset, args.model_num))
    # state = {'epoch': epoch+1,
    #          'model_RGB_state': model_RGB.state_dict(),
    #          'optimizer_RGB_state' : optimizer_RGB.state_dict(),}
    # if args.pretrain:
    #     torch.save(state, pjoin(args.model_savepath, "{}_{}_{}_{}_RGB_final.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
    # elif args.l1regular:
    #     torch.save(state, pjoin(args.model_savepath, "{}_{}_{}_{}_rgls_final.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
    # else:
    #     torch.save(state, pjoin(args.model_savepath, "{}_{}_{}_{}_nopretrain_final.pkl".format(args.arch_RGB, args.dataset, args.loss, args.model_num)))
    writer.export_scalars_to_json("./{}_{}_{}_{}.json".format(args.arch_RGB, args.dataset, args.loss, args.model_num))
    writer.close()
# --- fragment: tail of an inpainting-GAN training iteration; the enclosing
# loop and the definitions of fd_/pd_optimizer, g_model, fd_model, pd_model,
# writer, args, i, the losses, img/masked/refine_result are outside this view ---

fd_optimizer.step()
# patch-discriminator update
pd_optimizer.zero_grad()
pd_loss.backward()
pd_optimizer.step()
if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
    #torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_{i + 1}.pth')
    #torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_{i + 1}.pth')
    #torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_{i + 1}.pth')
    # NOTE(review): fixed filenames overwrite the same checkpoint every save
    # (the per-iteration names above are commented out) — confirm intended.
    torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_10000.pth')
    torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_10000.pth')
    torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_10000.pth')
if (i + 1) % args.log_interval == 0:
    writer.add_scalar('g_loss/recon_loss', recon_loss.item(), i + 1)
    writer.add_scalar('g_loss/cons_loss', cons_loss.item(), i + 1)
    writer.add_scalar('g_loss/gan_loss', gan_loss.item(), i + 1)
    writer.add_scalar('g_loss/total_loss', total_loss.item(), i + 1)
    writer.add_scalar('d_loss/fd_loss', fd_loss.item(), i + 1)
    writer.add_scalar('d_loss/pd_loss', pd_loss.item(), i + 1)


def denorm(x):
    """Map a tensor from [-1, 1] back to [0, 1] (clamped, in place)."""
    out = (x + 1) / 2  # [-1,1] -> [0,1]
    return out.clamp_(0, 1)


if (i + 1) % args.vis_interval == 0:
    # side-by-side: original | masked input | refined output
    ims = torch.cat([img, masked, refine_result], dim=3)
    writer.add_images('raw_masked_refine', denorm(ims), i + 1)
writer.close()
def main(config):
    """Train the attribute-conditioned layout-to-image GAN.

    Builds the generator and the three discriminators (image, object,
    attribute), optionally resumes them from checkpoints, then runs the
    alternating D/G training loop with console/TensorBoard logging and
    periodic checkpointing.

    Args:
        config: parsed options namespace with model, data and loop
            hyper-parameters (batch_size, z_dim, niter, lambda_* loss
            weights, log/save/tensorboard step sizes, resume_iter, ...).

    Side effects:
        Reads ``matrix_obj_vs_att.pt`` from the working directory; writes
        checkpoints and TensorBoard event files under the experiment dirs
        created by ``prepare_dir``.  Relies on a module-level ``pos_weight``
        tensor for the attribute BCE losses.
    """
    # Object->attribute co-occurrence matrix; one row per object class,
    # used as sampling weights when swapping in plausible fake attributes.
    matrix = torch.load("matrix_obj_vs_att.pt")
    cudnn.benchmark = True
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    attribute_nums = 106  # size of the attribute vocabulary

    data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                       attribute_embedding=attribute_nums,
                                       image_size=config.image_size)
    vocab_num = data_loader.dataset.num_objects

    # Generator: cLSTM-based unless clstm_layers == 0.
    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)
    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)

    # Spectral normalization on every discriminator.
    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)
    netD_att = add_sn(netD_att)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate,
                                      [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(),
                                            config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(),
                                             config.learning_rate, [0.5, 0.999])
    netD_att_optimizer = torch.optim.Adam(netD_att.parameters(),
                                          config.learning_rate, [0.5, 0.999])

    # Resume from checkpoints.  Only netG's returned iteration drives the
    # loop; the discriminators are restored for their weights only.
    load_model(netD_object, model_dir=model_save_dir, appendix='netD_object',
               iter=config.resume_iter)
    load_model(netD_att, model_dir=model_save_dir, appendix='netD_attribute',
               iter=config.resume_iter)
    load_model(netD_image, model_dir=model_save_dir, appendix='netD_image',
               iter=config.resume_iter)
    start_iter = load_model(netG, model_dir=model_save_dir, appendix='netG',
                            iter=config.resume_iter)

    data_iter = iter(data_loader)

    if start_iter < config.niter:
        if config.use_tensorboard:
            writer = SummaryWriter(log_save_dir)

        for i in range(start_iter, config.niter):
            # =========================================================== #
            # 1. Preprocess input data                                    #
            # =========================================================== #
            try:
                batch = next(data_iter)
            except StopIteration:
                # Epoch exhausted: restart the loader.  (Was a bare
                # `except:`, which also swallowed KeyboardInterrupt.)
                data_iter = iter(data_loader)
                batch = next(data_iter)
            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)
            # Indices of objects that carry at least one annotated attribute.
            att_idx = attribute.sum(dim=1).nonzero().squeeze()

            # =========================================================== #
            # 2. Train the discriminators                                 #
            # =========================================================== #
            # NOTE(review): obj_to_img is left on the CPU here while every
            # other tensor moves to `device` — confirm this is intentional.
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift \
                = imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), obj_to_img, z.to(
                    device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)
            attribute_GT = attribute.clone()

            # Estimate attributes: objects without annotations get the
            # attribute discriminator's top prediction filled in.
            attribute_est = attribute.clone()
            att_mask = torch.zeros(attribute.shape[0])
            att_mask = att_mask.scatter(0, att_idx, 1).to(device)
            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            estimated_att = netD_att(crops_input)
            max_idx = estimated_att.argmax(1)
            # Zero the prediction for already-annotated rows.
            max_idx = max_idx.float() * (~att_mask.byte()).float().to(device)
            for row in range(attribute.shape[0]):
                if row not in att_idx:
                    attribute_est[row, int(max_idx[row])] = 1

            # Swap in fake-but-plausible attributes for half the objects in
            # the first third of the images; those images are excluded from
            # the image reconstruction loss in the G step.
            num_img_to_change = math.floor(imgs.shape[0] / 3)
            for img_idx in range(num_img_to_change):
                obj_indices = torch.nonzero(obj_to_img == img_idx).view(-1)
                num_objs_to_change = math.floor(len(obj_indices) / 2)
                for changed, obj_idx in enumerate(obj_indices):
                    if changed >= num_objs_to_change:
                        break
                    obj = objs[obj_idx]
                    old_attributes = torch.nonzero(
                        attribute_GT[obj_idx]).view(-1)
                    # Sample 1-2 replacement attributes, weighted by the
                    # co-occurrence matrix with current attributes zeroed.
                    # (range(attribute_nums) replaces a hard-coded 106.)
                    new_attribute = random.choices(
                        range(attribute_nums),
                        matrix[obj].scatter(0, old_attributes.cpu(), 0),
                        k=random.randrange(1, 3))
                    attribute[obj_idx] = 0  # remove all attributes for obj
                    attribute[obj_idx] = attribute[obj_idx].scatter(
                        0, torch.LongTensor(new_attribute).to(device), 1)
                    # Mirror the change into the estimated attributes.
                    attribute_est[obj_idx] = 0
                    attribute_est[obj_idx] = attribute[obj_idx].scatter(
                        0, torch.LongTensor(new_attribute).to(device), 1)

            # Generate fake images.
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, \
                img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # Image adversarial loss on fakes (target 0); detach so G gets
            # no gradient from the D step.
            out_logits = netD_image(img_rec.detach())
            d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits = netD_image(img_rand.detach())
            d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits = netD_image(img_shift.detach())
            d_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            d_image_adv_loss_fake = 0.4 * d_image_adv_loss_fake_rec \
                + 0.4 * d_image_adv_loss_fake_rand \
                + 0.2 * d_image_adv_loss_fake_shift

            # Image adversarial loss on real images (target 1).
            out_logits = netD_image(imgs)
            d_image_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # Object adversarial loss on fake crops (target 0).  Renamed
            # from the misleading `g_object_adv_loss_rec`: this is a D-phase
            # loss, and the old name shadowed the real G-phase variable.
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            d_object_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits, _ = netD_object(crops_rand.detach(), objs)
            d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            out_logits, _ = netD_object(crops_shift.detach(), objs)
            d_object_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))
            d_object_adv_loss_fake = 0.4 * d_object_adv_loss_fake_rec \
                + 0.4 * d_object_adv_loss_fake_rand \
                + 0.2 * d_object_adv_loss_fake_shift

            # Object adversarial + class losses on real crops.
            out_logits_src, out_logits_cls = netD_object(
                crops_input.detach(), objs)
            d_object_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)

            # Attribute classification loss, on GT-annotated objects only.
            att_cls = netD_att(crops_input.detach())
            att_idx = attribute_GT.sum(dim=1).nonzero().squeeze()
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            attribute_annotated = torch.index_select(attribute_GT, 0, att_idx)
            d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Backward and optimize.
            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake + d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake + d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real
            d_loss += config.lambda_att_cls * d_object_att_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()
            netD_att.zero_grad()
            d_loss.backward()
            netD_image_optimizer.step()
            netD_object_optimizer.step()
            netD_att_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()
            loss['D/object_att_cls_loss'] = d_object_att_cls_loss_real.item()

            # =========================================================== #
            # 3. Train the generator                                      #
            # =========================================================== #
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, \
                img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # Image reconstruction loss; images whose attributes were
            # changed above are masked out of this term.
            rec_img_mask = torch.ones(imgs.shape[0]).scatter(
                0, torch.LongTensor(range(num_img_to_change)), 0).to(device)
            g_img_rec_loss = rec_img_mask * torch.abs(img_rec - imgs).view(
                imgs.shape[0], -1).mean(1)
            g_img_rec_loss = g_img_rec_loss.sum() / (imgs.shape[0] - num_img_to_change)

            g_z_rec_loss_rand = torch.abs(z_rand_rec - z).mean()
            g_z_rec_loss_shift = torch.abs(z_rand_shift - z).mean()
            g_z_rec_loss = 0.5 * g_z_rec_loss_rand + 0.5 * g_z_rec_loss_shift

            # KL divergence of the posterior (mu, logvar) against N(0, I).
            kl_element = mu.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Image adversarial loss for G (target 1, non-saturating).
            out_logits = netD_image(img_rec)
            g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            out_logits = netD_image(img_rand)
            g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            out_logits = netD_image(img_shift)
            g_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))
            g_image_adv_loss_fake = 0.4 * g_image_adv_loss_fake_rec \
                + 0.4 * g_image_adv_loss_fake_rand \
                + 0.2 * g_image_adv_loss_fake_shift

            # Object adversarial / class / attribute losses on rec crops.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)
            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)
            att_cls = netD_att(crops_input_rec)
            # Here att_idx follows the (possibly modified) attribute tensor,
            # not attribute_GT as in the D step.
            att_idx = attribute.sum(dim=1).nonzero().squeeze()
            attribute_annotated = torch.index_select(attribute, 0, att_idx)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rec = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # ... on random-z crops.
            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)
            att_cls = netD_att(crops_rand)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rand = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # ... on shifted-layout crops.
            out_logits_src, out_logits_cls = netD_object(crops_shift, objs)
            g_object_adv_loss_shift = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_shift = F.cross_entropy(out_logits_cls, objs)
            att_cls = netD_att(crops_shift)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_shift = F.binary_cross_entropy_with_logits(
                att_cls_annotated, attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Blend rec/rand/shift variants 0.4/0.4/0.2, as in the D step.
            g_object_att_cls_loss = 0.4 * g_object_att_cls_loss_rec \
                + 0.4 * g_object_att_cls_loss_rand \
                + 0.2 * g_object_att_cls_loss_shift
            g_object_adv_loss = 0.4 * g_object_adv_loss_rec \
                + 0.4 * g_object_adv_loss_rand \
                + 0.2 * g_object_adv_loss_shift
            g_object_cls_loss = 0.4 * g_object_cls_loss_rec \
                + 0.4 * g_object_cls_loss_rand \
                + 0.2 * g_object_cls_loss_shift

            # Backward and optimize.
            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_att_cls * g_object_att_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()
            g_loss.backward()
            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()
            loss['G/object_att_cls_loss'] = g_object_att_cls_loss.item()

            # =========================================================== #
            # 4. Log                                                      #
            # =========================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, roi_value in loss.items():
                    log += ", {}: {:.4f}".format(tag, roi_value)
                print(log)

            if (i + 1) % config.tensorboard_step == 0 and config.use_tensorboard:
                for tag, roi_value in loss.items():
                    writer.add_scalar(tag, roi_value, i + 1)
                writer.add_images(
                    'Result/crop_real',
                    imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
                writer.add_images(
                    'Result/crop_real_rec',
                    imagenet_deprocess_batch(crops_input_rec).float() / 255, i + 1)
                writer.add_images(
                    'Result/crop_rand',
                    imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_real',
                    imagenet_deprocess_batch(imgs).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_real_rec',
                    imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_fake_rand',
                    imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:
                save_model(netG, model_dir=model_save_dir, appendix='netG',
                           iter=i + 1, save_num=2, save_step=config.save_step)
                save_model(netD_image, model_dir=model_save_dir,
                           appendix='netD_image', iter=i + 1, save_num=2,
                           save_step=config.save_step)
                save_model(netD_object, model_dir=model_save_dir,
                           appendix='netD_object', iter=i + 1, save_num=2,
                           save_step=config.save_step)
                save_model(netD_att, model_dir=model_save_dir,
                           appendix='netD_attribute', iter=i + 1, save_num=2,
                           save_step=config.save_step)

        if config.use_tensorboard:
            writer.close()
image, func_label2color=visualization.VOClabel2colormap, threshold=None, norm=False) # MASK = eq_mask[0].detach().cpu().numpy().astype(np.uint8)*255 loss_dict = { 'loss': loss.item(), 'loss_cls': loss_cls.item(), 'loss_er': loss_er.item(), 'loss_ecr': loss_ecr.item() } itr = optimizer.global_step - 1 tblogger.add_scalars('loss', loss_dict, itr) tblogger.add_scalar('lr', optimizer.param_groups[0]['lr'], itr) tblogger.add_image('Image', input_img, itr) # tblogger.add_image('Mask', MASK, itr) tblogger.add_image('CLS1', CLS1, itr) tblogger.add_image('CLS2', CLS2, itr) tblogger.add_image('CLS_RV1', CLS_RV1, itr) tblogger.add_image('CLS_RV2', CLS_RV2, itr) tblogger.add_images('CAM1', CAM1, itr) tblogger.add_images('CAM2', CAM2, itr) tblogger.add_images('CAM_RV1', CAM_RV1, itr) tblogger.add_images('CAM_RV2', CAM_RV2, itr) else: print('') timer.reset_stage() torch.save(model.module.state_dict(), args.session_name + '.pth')
def train_gan(dataloader, model_folder, netG, netD, netS, netEs, netEb, args):
    """Train the segmentation/background-conditioned text-to-image GAN.

    Alternates discriminator and generator updates per batch, halves both
    learning rates every ``args.epoch_decay`` epochs, logs scalars/images to
    TensorBoard, and checkpoints all trainable networks every
    ``args.save_freq`` epochs.

    Parameters:
    ----------
    dataloader: data loader. refers to fuel.dataset
    model_folder: the folder to save the models weights
    netG: Generator
    netD: Descriminator
    netS: Segmentation Network (frozen; loaded from args.unet_checkpoint)
    netEs: Segmentation Encoder
    netEb: Background Encoder
    args: options namespace (d_lr, g_lr, maxepoch, epoch_decay, KL_COE,
        manipulate, reuse_weights, load_from_epoch, verbose_per_iter,
        n_plots, save_freq, unet_checkpoint)
    """
    d_lr = args.d_lr
    g_lr = args.g_lr
    tot_epoch = args.maxepoch

    ''' configure optimizers '''
    optimizerD = optim.Adam(netD.parameters(), lr=d_lr, betas=(0.5, 0.999))
    # G, Es and Eb are updated jointly by the generator optimizer.
    paramsG = list(netG.parameters()) + list(netEs.parameters()) + list(
        netEb.parameters())
    optimizerG = optim.Adam(paramsG, lr=g_lr, betas=(0.5, 0.999))

    ''' create tensorboard writer '''
    writer = SummaryWriter(model_folder)

    # --- load model from checkpoint ---
    netS.load_state_dict(torch.load(args.unet_checkpoint))
    if args.reuse_weights:
        G_weightspath = os.path.join(
            model_folder, 'G_epoch{}.pth'.format(args.load_from_epoch))
        D_weightspath = os.path.join(
            model_folder, 'D_epoch{}.pth'.format(args.load_from_epoch))
        Es_weightspath = os.path.join(
            model_folder, 'Es_epoch{}.pth'.format(args.load_from_epoch))
        Eb_weightspath = os.path.join(
            model_folder, 'Eb_epoch{}.pth'.format(args.load_from_epoch))
        netG.load_state_dict(torch.load(G_weightspath))
        netD.load_state_dict(torch.load(D_weightspath))
        netEs.load_state_dict(torch.load(Es_weightspath))
        netEb.load_state_dict(torch.load(Eb_weightspath))
        start_epoch = args.load_from_epoch + 1
        # Replay the LR decay schedule up to the resume point.
        d_lr /= 2**(start_epoch // args.epoch_decay)
        g_lr /= 2**(start_epoch // args.epoch_decay)
    else:
        start_epoch = 1

    # --- Start training ---
    for epoch in range(start_epoch, tot_epoch + 1):
        start_timer = time.time()

        '''decay learning rate every epoch_decay epoches'''
        if epoch % args.epoch_decay == 0:
            d_lr = d_lr / 2
            g_lr = g_lr / 2
            set_lr(optimizerD, d_lr)
            set_lr(optimizerG, g_lr)

        netG.train()
        netD.train()
        netEs.train()
        netEb.train()
        netS.eval()  # segmentation net stays frozen

        for i, data in enumerate(dataloader):
            images, w_images, segs, txt_data, txt_len, _ = data

            # create labels (real=1, fake=0)
            r_labels = torch.FloatTensor(images.size(0)).fill_(1).cuda()
            f_labels = torch.FloatTensor(images.size(0)).fill_(0).cuda()
            it = epoch * len(dataloader) + i

            # to cuda
            images = images.cuda()
            w_images = w_images.cuda()
            segs = segs.cuda()
            txt_data = txt_data.cuda()

            ''' UPDATE D '''
            for p in netD.parameters():
                p.requires_grad = True
            optimizerD.zero_grad()

            if args.manipulate:
                bimages = images  # for text and seg mismatched backgrounds
                bsegs = segs  # background segmentations
            else:
                bimages = roll(
                    images, 2, dim=0)  # for text and seg mismatched backgrounds
                bsegs = roll(segs, 2, dim=0)  # background segmentations
            segs = roll(segs, 1, dim=0)  # for text mismatched segmentations

            segs_code = netEs(segs)  # segmentation encoding
            bkgs_code = netEb(bimages)  # background image encoding
            mean_var, smean_var, bmean_var, f_images, z_list = netG(
                txt_data, txt_len, segs_code, bkgs_code)
            # Detached copy so the D update does not backprop into G.
            f_images_cp = f_images.data.cuda()

            r_logit, r_logit_c = netD(images, txt_data, txt_len)
            _, w_logit_c = netD(w_images, txt_data, txt_len)
            f_logit, _ = netD(f_images_cp, txt_data, txt_len)

            d_adv_loss = compute_d_loss(r_logit, r_logit_c, w_logit_c, f_logit,
                                        r_labels, f_labels)
            d_loss = d_adv_loss
            d_loss.backward()
            optimizerD.step()
            optimizerD.zero_grad()

            ''' UPDATE G '''
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation
            optimizerG.zero_grad()

            f_logit, f_logit_c = netD(f_images, txt_data, txt_len)
            g_adv_loss = compute_g_loss(f_logit, f_logit_c, r_labels)
            f_segs = netS(f_images)  # segmentation from Unet
            seg_consist_loss = shape_consistency_loss(f_segs, segs)
            bkg_consist_loss = background_consistency_loss(
                f_images, bimages, f_segs, bsegs)
            kl_loss = get_kl_loss(mean_var[0], mean_var[1])  # text
            skl_loss = get_kl_loss(smean_var[0], smean_var[1])  # segmentation
            bkl_loss = get_kl_loss(bmean_var[0], bmean_var[1])  # background
            if args.manipulate:
                idt_consist_loss = idt_consistency_loss(f_images, images)
            else:
                idt_consist_loss = 0.

            g_loss = g_adv_loss \
                + args.KL_COE * kl_loss \
                + args.KL_COE * skl_loss \
                + args.KL_COE * bkl_loss \
                + 10 * seg_consist_loss \
                + 10 * bkg_consist_loss \
                + 10 * idt_consist_loss
            g_loss.backward()
            optimizerG.step()
            optimizerG.zero_grad()

            # --- visualize train samples----
            if it % args.verbose_per_iter == 0:
                writer.add_images('txt', (images[:args.n_plots] + 1) / 2, it)
                writer.add_images('background',
                                  (bimages[:args.n_plots] + 1) / 2, it)
                writer.add_images('segmentation',
                                  segs[:args.n_plots].repeat(1, 3, 1, 1), it)
                writer.add_images('generated',
                                  (f_images[:args.n_plots] + 1) / 2, it)
                writer.add_scalar('g_lr', g_lr, it)
                # BUGFIX: previously logged g_lr under the 'd_lr' tag, so the
                # discriminator LR was never actually recorded.
                writer.add_scalar('d_lr', d_lr, it)
                writer.add_scalar('g_loss', to_numpy(g_loss).mean(), it)
                writer.add_scalar('d_loss', to_numpy(d_loss).mean(), it)
                writer.add_scalar('imkl_loss', to_numpy(kl_loss).mean(), it)
                writer.add_scalar('segkl_loss', to_numpy(skl_loss).mean(), it)
                writer.add_scalar('bkgkl_loss', to_numpy(bkl_loss).mean(), it)
                writer.add_scalar('seg_consist_loss',
                                  to_numpy(seg_consist_loss).mean(), it)
                writer.add_scalar('bkg_consist_loss',
                                  to_numpy(bkg_consist_loss).mean(), it)
                if args.manipulate:
                    writer.add_scalar('idt_consist_loss',
                                      to_numpy(idt_consist_loss).mean(), it)

        # --- save weights ---
        if epoch % args.save_freq == 0:
            # Move to CPU before serializing so checkpoints load anywhere.
            netG = netG.cpu()
            netD = netD.cpu()
            netEs = netEs.cpu()
            netEb = netEb.cpu()
            torch.save(
                netD.state_dict(),
                os.path.join(model_folder, 'D_epoch{}.pth'.format(epoch)))
            torch.save(
                netG.state_dict(),
                os.path.join(model_folder, 'G_epoch{}.pth'.format(epoch)))
            torch.save(
                netEs.state_dict(),
                os.path.join(model_folder, 'Es_epoch{}.pth'.format(epoch)))
            torch.save(
                netEb.state_dict(),
                os.path.join(model_folder, 'Eb_epoch{}.pth'.format(epoch)))
            print('save weights at {}'.format(model_folder))
            netD = netD.cuda()
            netG = netG.cuda()
            netEs = netEs.cuda()
            netEb = netEb.cuda()

        end_timer = time.time() - start_timer
        print('epoch {}/{} finished [time = {}s] ...'.format(
            epoch, tot_epoch, end_timer))

    writer.close()