Exemplo n.º 1
0
def train_loop(model, loader, test_loader, opt):
    """Train an autoencoder-style model with a plain MSE reconstruction loss.

    Args:
        model: network mapping an image batch to a reconstruction of it.
        loader: training DataLoader yielding (image, label) pairs; labels
            are ignored.
        test_loader: unused here; kept for signature compatibility.
        opt: options namespace; reads opt.cuda (GPU index), opt.exp
            (experiment name) and opt.epochs.
    """
    device = torch.device('cuda:{}'.format(opt.cuda))
    print(opt.exp)
    optim = torch.optim.Adam(model.parameters(), 5e-4, betas=(0.5, 0.999))
    writer = SummaryWriter('tblog/%s' % opt.exp)
    for e in tqdm(range(opt.epochs)):
        losses = []
        model.train()
        for (x, _) in tqdm(loader):
            x = x.to(device)
            # Tile grayscale inputs to 3 channels so the model always sees
            # RGB-shaped tensors.
            if x.size(1) == 1:
                x = x.repeat(1, 3, 1, 1)

            # Inputs default to requires_grad=False; kept for explicitness.
            x.requires_grad = False
            out = model(x)
            rec_err = (out - x)**2
            loss = rec_err.mean()
            losses.append(loss.item())

            optim.zero_grad()
            loss.backward()
            optim.step()

        losses = np.mean(losses)
        writer.add_scalar('rec_err', losses, e)
        # Inputs/reconstructions appear to be in [-1, 1]; rescale to [0, 1]
        # for display.  NOTE(review): logs only the last batch of the epoch.
        writer.add_images('recons', torch.cat((x, out)).cpu() * 0.5 + 0.5, e)
        print('epochs:{}, recon error:{}'.format(e, losses))

    # BUG FIX: flush and close the event-file writer before saving weights;
    # previously pending events could be lost on interpreter exit.
    writer.close()
    torch.save(model.state_dict(), 'models/{}.pth'.format(opt.exp))
Exemplo n.º 2
0
    def train(self,
              data_loader,
              epochs=20,
              log_dir="runs/test/",
              log_freq=500):
        """ run training loop

        Args:
            data_loader: iterable of (images, labels); labels are ignored.
            epochs: number of full passes over data_loader.
            log_dir: directory for tensorboard event files.
            log_freq: log a sample-figure every `log_freq` global iterations.
        """
        tb_logger = SummaryWriter(log_dir=log_dir)
        self.log_dir = log_dir
        # Fixed latent codes so "epoch/sample1" shows how the same noise
        # decodes as training progresses (tensorboard logging only).
        constant_noise = torch.randn(64, self.latent_dim, device=self.device)

        for ep in range(1, epochs + 1):

            print("\n", "=" * 35, f"training epoch {ep}", "=" * 35, "\n")

            for it, (imgs, _) in tqdm(enumerate(data_loader)):
                self.glob_it += 1

                imgs = imgs.to(self.device)
                enc_out = self.enc(imgs)
                # Encoder emits mean and log-variance stacked along dim 1.
                mu, logvar = enc_out.chunk(chunks=2, dim=1)

                # Reparametrization trick: z = mu + eps * sigma.
                std = torch.exp(0.5 * logvar)
                e = torch.randn(std.shape, device=self.device)
                z = mu + e * std

                imgs_recon = self.dec(z)

                loss = self.ELBO(imgs_recon, imgs, mu, logvar)

                self.enc.zero_grad()
                self.dec.zero_grad()
                loss.backward()
                self.optim.step()

                tb_logger.add_scalar("train_loss", loss, self.glob_it)

                if self.glob_it % log_freq == 0:
                    # log some images to tensorboard
                    tb_logger.add_figure("samples",
                                         self.get_mXn_samples_grid(4, 4),
                                         self.glob_it)
                    print(
                        f"epoch {ep}, iter {it} (total iter {self.glob_it}): train_loss = {loss}"
                    )

            # per epoch logging
            print(
                f"epoch {ep}, iter {it} (total iter {self.glob_it}): train_loss = {loss}"
            )
            tb_logger.add_images("epoch/sample1",
                                 self.sample(noise=constant_noise), ep)
            tb_logger.add_images("epoch/sample2", self.sample(num_images=64),
                                 ep)
            # save model at end of each epoch
            self.save_model(model_name=f"vae_ep{ep}.pt", idx=self.glob_it)

        # BUG FIX: close the writer so pending events are flushed to disk;
        # previously it was never released.
        tb_logger.close()
Exemplo n.º 3
0
class TensorboardVisualizer:
    """ A wrapper class for tensorboardX.

    Note:
        Original tensorboardX API call is supported: attributes not defined
        here are delegated to the underlying SummaryWriter.
    """
    def __init__(self, experiment=None, **kwargs):
        comment = '_' + str(experiment) if experiment is not None else ''
        if self.experiment_exists(experiment):
            self.has_writer = False
            raise ValueError(
                'experiment [{}] already exists.'.format(experiment))
        else:
            self.has_writer = True
        self.writer = SummaryWriter(comment=comment, **kwargs)

    def __del__(self):
        # has_writer may be missing entirely if __init__ raised before
        # reaching the experiment_exists() check.
        if getattr(self, 'has_writer', False):
            self.writer.close()

    @staticmethod
    def experiment_exists(experiment):
        """Return True if a run directory for `experiment` already exists."""
        if not mv.isdir('runs'):
            return False

        all_experiments = mv.listdir('runs')
        # Run dirs look like "<prefix>_<experiment>"; keep the suffix part.
        all_experiments = [e.split('_', 3)[-1] for e in all_experiments]
        return experiment in all_experiments

    def plot(self, name, x, y, group='data'):
        """Add scalar point (x, y) under tag '<group>/<name>'."""
        tag = group + '/' + name
        self.writer.add_scalar(tag, y, x)

    def imshow(self, name, img_tensor, global_step=None, group='image'):
        """ Accept NCHW numpy.ndarray or torch tensor as input.
        """
        imgs = mv.make_np(img_tensor)
        assert imgs.ndim == 4

        # Tile single-channel batches to 3 channels for display.
        if imgs.shape[1] == 1:
            imgs = np.concatenate([imgs, imgs, imgs], 1)

        tag = group + '/' + name
        self.writer.add_images(tag, imgs, global_step)

    def log(self, info, global_step=None, tag='text/log'):
        """Append a timestamped line of text to the tensorboard text tab."""
        text_string = '[{time}] {info} <br>'.format(
            time=time.strftime('%y-%m-%d %H:%M:%S'), info=info)
        self.writer.add_text(tag, text_string, global_step)

    def __getattr__(self, name):
        # BUG FIX: guard against infinite recursion when 'writer' itself is
        # missing (e.g. after __init__ raised before assigning it) --
        # looking up self.writer would re-enter __getattr__ forever.
        if name == 'writer':
            raise AttributeError(name)
        return getattr(self.writer, name)
Exemplo n.º 4
0
class MetricCounter:
    """Accumulates scalar metrics and image batches between tensorboard
    flushes, and tracks the best PSNR seen so far."""

    def __init__(self, exp_name=None):
        self.writer = SummaryWriter(exp_name)
        logging.basicConfig(filename='{}.log'.format(exp_name),
                            level=logging.DEBUG)
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)
        self.best_metric = 0

    def add_image(self, x: np.ndarray, tag: str):
        """Queue one image under `tag` for the next tensorboard flush."""
        self.images[tag].append(x)

    def clear(self):
        """Drop every accumulated metric and image."""
        self.metrics = defaultdict(list)
        self.images = defaultdict(list)

    def add_losses(self, l_G, l_content, l_D=0):
        """Record generator/content/adversarial/discriminator losses.

        The adversarial component is derived as l_G - l_content.
        """
        names = ('G_loss', 'G_loss_content', 'G_loss_adv', 'D_loss')
        values = (l_G, l_content, l_G - l_content, l_D)
        for name, value in zip(names, values):
            self.metrics[name].append(value)

    def add_metrics(self, psnr, ssim):
        """Record one PSNR/SSIM measurement."""
        self.metrics['PSNR'].append(psnr)
        self.metrics['SSIM'].append(ssim)

    def loss_message(self):
        """Format running means over the last WINDOW_SIZE samples."""
        parts = []
        for k in ('G_loss', 'PSNR', 'SSIM'):
            parts.append(f'{k}={np.mean(self.metrics[k][-WINDOW_SIZE:]):.4f}')
        return '; '.join(parts)

    def write_to_tensorboard(self, epoch_num, validation=False):
        """Flush mean scalars and all queued image batches to tensorboard."""
        scalar_prefix = 'Validation' if validation else 'Train'
        for tag in ('G_loss', 'D_loss', 'G_loss_adv', 'G_loss_content', 'SSIM',
                    'PSNR'):
            self.writer.add_scalar(f'{scalar_prefix}_{tag}',
                                   np.mean(self.metrics[tag]),
                                   global_step=epoch_num)
        for tag in self.images:
            queued = self.images[tag]
            if not queued:
                continue
            batch = np.array(queued)
            # BGR -> RGB, then scale uint8 [0, 255] down to float [0, 1].
            self.writer.add_images(tag,
                                   batch[:, :, :, ::-1].astype('float32') /
                                   255,
                                   dataformats='NHWC',
                                   global_step=epoch_num)
            self.images[tag] = []

    def update_best_model(self):
        """Return True (and remember the value) when mean PSNR improved."""
        cur_metric = np.mean(self.metrics['PSNR'])
        if not self.best_metric < cur_metric:
            return False
        self.best_metric = cur_metric
        return True
Exemplo n.º 5
0
    def train(self):
        """Run the inpainting training loop.

        NOTE(review): keep_training is never set to False inside the loop,
        so this trains until externally interrupted; checkpoints are saved
        every self.opt.save_interval iterations, losses/images are logged
        every self.opt.log_interval iterations.
        """
        writer = SummaryWriter(log_dir="log_info")
        self.G.train()
        if self.opt.finetune:
            print("here")
            # Fine-tuning: only parameters still marked trainable get a
            # fresh optimizer.
            self.optm_G = optim.Adam(filter(lambda p:p.requires_grad, self.G.parameters()), lr = self.lr)
        train_loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.opt.batch_size,
            num_workers=self.opt.n_threads,
            drop_last=True,
            shuffle=True
        )
        keep_training = True
        epoch = 0
        # Global iteration counter; resumes from a loaded checkpoint.
        i = self.start_iter
        print("starting training")
        s_time = time.time()
        while keep_training:
            epoch += 1
            print("epoch: {:d}".format(epoch))
            for items in train_loader:
                i += 1
                gt_images, gray_image, gt_edges, masks = self.cuda(*items)
              #  masks = torch.cat([masks]*3, dim = 1)
                self.gray_image = gray_image

                # Apply the mask to produce the corrupted inputs; only one
                # mask channel is used for the single-channel edge map.
                # NOTE(review): mask polarity (1 = keep?) inferred from
                # usage -- confirm against the dataset.
                masked_images = gt_images * masks
                masked_edges = gt_edges * masks[:,0:1,:,:]

                self.forward(masked_images, masks, masked_edges, gt_images, gt_edges)
                self.update_parameters()

                if i % self.opt.log_interval == 0:
                    # Report losses averaged over the last log_interval
                    # iterations, then reset the running accumulators.
                    e_time = time.time()
                    int_time = e_time - s_time
                    print("epoch:{:d}, iteration:{:d}".format(epoch, i), ", l1_loss:", self.l1_loss/self.opt.log_interval, ", time_taken:", int_time)
                    writer.add_scalars("loss_val", {"l1_loss":self.l1_loss*self.opt.batch_size/self.opt.log_interval, "D_loss":self.D_loss/self.opt.log_interval,"E_loss":self.E_loss*self.opt.batch_size/self.opt.log_interval}, i)
                    masked_images = masked_images.cpu()
                    fake_images = self.fake_B.cpu()
                    fake_edges = self.edge_fake[1].cpu()
                    # Tile the single-channel edge map to 3 channels so it
                    # can be concatenated with the RGB images below.
                    fake_edges = torch.cat([fake_edges]*3, dim = 1)
                    images = torch.cat([masked_images[0:3], fake_images[0:3], fake_edges[0:3]], dim = 0)
                    writer.add_images("imgs", images, i)
                    s_time = time.time()
                    self.l1_loss = 0.0
                    self.D_loss = 0.0
                    self.E_loss = 0.0

                if i % self.opt.save_interval == 0:
                    save_ckpt('{:s}/ckpt/g_{:d}.pth'.format(self.opt.save_dir, i ), [('generator', self.G)], [('optimizer_G', self.optm_G)], i )
                    if self.have_D:
                        save_ckpt('{:s}/ckpt/d_{:d}.pth'.format(self.opt.save_dir, i ), [('edge_D', self.edge_D)], [('optimizer_ED', self.optm_ED)], i )

        writer.close()
Exemplo n.º 6
0
class Visualizer():
    """Thin tensorboard facade used by the train scripts."""

    def __init__(self, top_out_path
                 ):  # This will cause error in the very old train scripts
        self.writer = SummaryWriter(top_out_path)

    def log_images(self, visuals, step):
        """|visuals|: dictionary of images to save, one add_images per key."""
        for label in visuals:
            self.writer.add_images(label, [visuals[label]], step)

    def log_scalars(self, scalars, step, main_tag='metrics'):
        """scalars: dictionary of scalar labels and values, written as one
        grouped add_scalars call."""
        self.writer.add_scalars(
            main_tag=main_tag, tag_scalar_dict=scalars, global_step=step)
Exemplo n.º 7
0
    def forward(self, x):
        """Forward pass through conv1..conv4 and the output head.

        When self.test is truthy, the first sample of the input and of each
        conv stage's activations are written once to tensorboard
        ('./runs/Pict'); self.test is then reset to 0 so the dump happens
        only on the first such call.  Returns (output, flattened_features).
        """
        self.x1 = x  # keep the raw input for visualization

        if self.test:
            # "TezhengTu" is pinyin for "feature map".
            TezhengTuWriter = SummaryWriter('./runs/Pict')
            TezhengTuWriter.add_image('countdown_1',
                                      self.x1[0],
                                      global_step=0,
                                      dataformats='CHW')

        x = self.conv1(x)
        self.x2 = x

        if self.test:
            # GYH presumably rescales/normalizes feature maps for display
            # -- TODO confirm its contract.
            TezhengTuWriter.add_images('countdown_2',
                                       GYH(self.x2[0]),
                                       global_step=1,
                                       dataformats='NCHW')

        x = self.conv2(x)
        self.x3 = x

        if self.test:
            TezhengTuWriter.add_images('countdown_3',
                                       GYH(self.x3[0]),
                                       global_step=2,
                                       dataformats='NCHW')

        x = self.conv3(x)
        self.x4 = x

        if self.test:
            TezhengTuWriter.add_images('countdown_4',
                                       GYH(self.x4[0]),
                                       global_step=3,
                                       dataformats='NCHW')

        x = self.conv4(x)
        self.x5 = x

        if self.test:
            TezhengTuWriter.add_images('countdown_5',
                                       GYH(self.x5[0]),
                                       global_step=4,
                                       dataformats='NCHW')
            TezhengTuWriter.close()
            # Disable further dumps; subsequent forwards skip all writers.
            self.test = 0

        x = x.view(
            x.size(0),
            -1)  # flatten the conv features to (batch_size, -1)

        output = self.out(x)
        return output, x  # return x for visualization
Exemplo n.º 8
0
def mark_pru():
    """Benchmark the encoder/decoder pair on the NYU benchmark split.

    Logs standard depth metrics (abs_rel, sq_rel, rmse, rmse_log, a1-a3)
    and prediction/label images to tensorboard per sample, and prints the
    running mean of the a3 accuracy.
    """
    net_e = load_encoder(after_f=True)
    net_e = nn.DataParallel(net_e).cuda()
    net_d = load_decoder(after_f=True)
    net_d = nn.DataParallel(net_d).cuda()

    data_loader = nyu_set.use_nyu_data(batch_s=1,
                                       max_len=400,
                                       isBenchmark=True)
    writer1 = SummaryWriter('/data/consistent_depth/gj_dir/benchmark_p2')

    with torch.no_grad():
        num = 0
        su = 0  # running sum of the a3 accuracy
        for data, label in data_loader:
            num += 1
            # autograd.Variable is a no-op wrapper on modern torch; kept.
            data = autograd.Variable(data.double().cuda(), requires_grad=False)

            prediction_d = net_d(net_e(data))

            abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 = compute_depth_errors(
                label, prediction_d)

            writer1.add_images('pre', prediction_d, global_step=num)

            writer1.add_scalar('rmse', rmse, global_step=num)
            writer1.add_scalar("abs_rel", abs_rel, global_step=num)
            writer1.add_scalar('sq_rel', sq_rel, global_step=num)
            writer1.add_scalar('rmse_log', rmse_log, global_step=num)
            writer1.add_scalar('a1', a1, global_step=num)
            writer1.add_scalar('a2', a2, global_step=num)
            writer1.add_scalar('a3', a3, global_step=num)

            writer1.add_images('label', label, global_step=num)
            su += a3.item()
            print(su / num)  # running mean of a3 so far
            # scaled_disp, _ = disp_to_depth(disp, 0.1, 10)
            # Saving colormapped depth image
            # vmax = np.percentile(disp_resized_np, 95)
        writer1.close()
        print('-> Done!')
Exemplo n.º 9
0
class TensorboardLogger(Logger):
    """Logger backend that forwards everything to a tensorboard
    SummaryWriter rooted at `path`."""

    def __init__(self, path):
        self.writer = SummaryWriter(path)

    def log_image(self, name, data, step):
        self.writer.add_image(name, data, step)

    def log_image_batch(self, name, data, step):
        self.writer.add_images(name, data, step)

    def log_number(self, name, data, step):
        self.writer.add_scalar(name, data, step)

    def log_text(self, name, data, step):
        self.writer.add_text(name, data, step)

    def log_figure(self, name, data, step):
        self.writer.add_figure(name, data, step)

    def log_embedding(self, name, data, step):
        # BUG FIX: add_embedding's signature is (mat, metadata, label_img,
        # global_step, tag) -- unlike the other add_* calls the tag comes
        # LAST, so the positional (name, data, step) call was sending the
        # tag string as the embedding matrix.
        self.writer.add_embedding(data, tag=name, global_step=step)
Exemplo n.º 10
0
class Logger():
	"""Small tensorboard wrapper that caps logged image batches and accepts
	dicts of tensors for scalar logging."""

	def __init__(self, log_path='log'):
		self.logger = SummaryWriter(log_path)
		# add_images only writes the first N images of a batch.
		self.add_images_maxnum = 10

	def add_scalar(self, tag, value, step):
		"""Log one scalar; a dict logs each entry's mean under its own key."""
		if isinstance(value, dict):
			for k, v in value.items():
				self.logger.add_scalar(k, v.mean().item(), step)
		else:
			self.logger.add_scalar(tag, value, step)

	def add_image(self, tag, value, step, dataformats='CHW'):
		# NOTE(review): dataformats is accepted but not forwarded; the
		# writer's own default is also 'CHW', so behavior is unchanged.
		self.logger.add_image(tag, value, step)

	def add_images(self, tag, value, step, dataformats='CHW'):
		value = value[:self.add_images_maxnum]
		# NOTE(review): dataformats is accepted but not forwarded here
		# either -- the writer then applies its own default.
		self.logger.add_images(tag, value, step)

	def add_graph(self, model, input_to_model=None):
		if input_to_model is None:
			input_to_model = torch.zeros((256, 3, 32, 128))
		# BUG FIX: this called self.add_graph(...) which recursed forever;
		# delegate to the underlying SummaryWriter instead.
		self.logger.add_graph(model, input_to_model)
Exemplo n.º 11
0
def benchmark_pruned():
    """Benchmark the teacher network on NYU data, logging errors per batch.

    Writes rmse/rel/mse scalars plus prediction and label images to
    tensorboard for every batch of the benchmark loader.
    """
    net = load_t_net(file=True)
    #net = load_pru_mod(after_finetune=True).double()
    net = nn.DataParallel(net)
    net = net.cuda()

    data_loader = nyu_set.use_nyu_data(batch_s=4,
                                       max_len=100,
                                       isBenchmark=True)
    writer1 = SummaryWriter('./gj_dir/benchmark_t_mod')

    Joint = JointLoss(opt=None).double().cuda()
    criterion = nn.MSELoss(reduction='mean').cuda()
    net.eval()

    num = 0
    for data, label in data_loader:
        num += 1
        target = label2target(label)

        # BUG FIX: `images` was read before ever being assigned; the loader
        # yields the input batch as `data`.
        images = autograd.Variable(data.double().cuda(), requires_grad=False)

        prediction_d = net.forward(images)[0]  # 0 is depth, 1 is confidence

        e_rmse = Joint.compute_rmse_error(prediction_d, target)
        e_rel = Joint.compute_l1_rel_error(prediction_d, target)
        loss = criterion(prediction_d, target["depth_gt"])
        writer1.add_images('pre', prediction_d, global_step=num)

        writer1.add_scalar('rmse', e_rmse, global_step=num)
        writer1.add_scalar("rel", e_rel, global_step=num)
        writer1.add_scalar('loss', loss, global_step=num)

        writer1.add_images('label', label, global_step=num)

        print("ok")

    # Flush event files once the benchmark pass completes.
    writer1.close()
Exemplo n.º 12
0
class tf_recorder:
    """Tensorboard recorder that auto-picks a fresh run directory and keeps
    its own iteration counter for all add_* calls."""

    def __init__(self, network_name, log_dir):
        # Idiom fix: use os.makedirs instead of shelling out to `mkdir -p`
        # (portable, and avoids interpolating log_dir into a shell string).
        os.makedirs(log_dir, exist_ok=True)
        # Probe <log_dir>/<name>_0 .. _999 for the first unused directory.
        for i in range(1000):
            self.targ = os.path.join(log_dir, '{}_{}'.format(network_name, i))
            if not os.path.exists(self.targ):
                self.writer = SummaryWriter(self.targ)
                break
        else:
            # BUG FIX: previously self.writer was silently left unset and
            # the first add_* call died with an AttributeError.
            raise RuntimeError('no free run directory under {}'.format(log_dir))

    def renew(self, subname):
        """Switch to '<targ>_<subname>' and reset the iteration counter."""
        self.writer = SummaryWriter('{}_{}'.format(self.targ, subname))
        self.niter = 0

    def add_scalar(self, index, val):
        self.writer.add_scalar(index, val, self.niter)

    def add_scalars(self, index, group_dict):
        # BUG FIX: this called add_scalar with a dict as the value; the
        # grouped-scalar API is add_scalars.
        self.writer.add_scalars(index, group_dict, self.niter)

    def add_images(self, tag, images):
        self.writer.add_images(tag, images, self.niter)

    def iter(self, tick=1):
        """Advance the step counter used by all add_* calls."""
        self.niter += tick
Exemplo n.º 13
0
def train_loop(model, loader, test_loader, opt):
    """Train a (possibly uncertainty-aware) autoencoder on x-ray images.

    When opt.u is falsy the model returns a reconstruction and is trained
    with plain MSE; when truthy it returns (mean, logvar) and is trained
    with the heteroscedastic loss exp(-logvar)*err + logvar.  After every
    epoch the model is evaluated with test_for_xray and AUC, losses and
    sample images are written to tensorboard.
    """
    device = torch.device('cuda:{}'.format(opt.cuda))
    print(opt.exp)
    optim = torch.optim.Adam(model.parameters(), 5e-4, betas=(0.5, 0.999))
    writer = SummaryWriter('log/%s' % opt.exp)
    for e in tqdm(range(opt.epochs)):
        l1s, l2s = [], []
        model.train()
        for (x, _) in tqdm(loader):
            x = x.to(device)
            x.requires_grad = False
            if not opt.u:
                out = model(x)
                rec_err = (out - x)**2
                loss = rec_err.mean()
                l1s.append(loss.item())
            else:
                mean, logvar = model(x)
                rec_err = (mean - x)**2
                # Heteroscedastic NLL: down-weight pixels with high
                # predicted variance, while penalizing large variances.
                loss1 = torch.mean(torch.exp(-logvar) * rec_err)
                loss2 = torch.mean(logvar)
                loss = loss1 + loss2
                l1s.append(rec_err.mean().item())
                l2s.append(loss2.item())

            optim.zero_grad()
            loss.backward()
            optim.step()
        auc = test_for_xray(opt, model, test_loader)
        if not opt.u:
            l1s = np.mean(l1s)
            writer.add_scalar('auc', auc, e)
            writer.add_scalar('rec_err', l1s, e)
            # Inputs appear to be in [-1, 1]; rescale to [0, 1] for display.
            writer.add_images('recons',
                              torch.cat((x, out)).cpu() * 0.5 + 0.5, e)
            print('epochs:{}, recon error:{}'.format(e, l1s))
        else:
            l1s = np.mean(l1s)
            l2s = np.mean(l2s)
            writer.add_scalar('auc', auc, e)
            writer.add_scalar('rec_err', l1s, e)
            writer.add_scalar('logvars', l2s, e)
            writer.add_images('recons',
                              torch.cat((x, mean)).cpu() * 0.5 + 0.5, e)
            writer.add_images('vars',
                              torch.cat((x * 0.5 + 0.5, logvar.exp())).cpu(),
                              e)
            print('epochs:{}, recon error:{}, logvars:{}'.format(e, l1s, l2s))

    # BUG FIX: flush and close the event-file writer before saving weights;
    # previously pending events could be lost on interpreter exit.
    writer.close()
    torch.save(model.state_dict(), './models/{}.pth'.format(opt.exp))
Exemplo n.º 14
0
def benchmark_pruned():
    """Benchmark the pruned model on NYU data, logging per-sample error.

    Runs the network on each benchmark sample, converts the predicted
    log-depth to inverse depth, and writes the error plus prediction /
    label / inverse-depth images to tensorboard.
    """
    net = load_pru_mod(after_finetune=True).double()
    net = nn.DataParallel(net)
    net = net.cuda()

    data_loader = nyu_set.use_nyu_data(batch_s=1,
                                       max_len=100,
                                       isBenchmark=True)
    writer1 = SummaryWriter('./gj_dir/benchmark_pru_mod')

    criterion = nn.MSELoss(reduction='mean').cuda()
    net.eval()

    num = 0
    for data, label in data_loader:
        num += 1

        # BUG FIX: `images` was read before ever being assigned; the loader
        # yields the input batch as `data`.
        images = Variable(data).double().cuda()
        label = Variable(label).double().cuda()

        # Reshape ...CHW -> XCHW
        shape = images.shape
        prediction_d = net.forward(images)[0]  # 0 is depth, 1 is confidence

        out_shape = shape[:-3] + prediction_d.shape[-2:]
        prediction_d = prediction_d.reshape(out_shape)
        # The network predicts log-depth; exponentiate to recover depth.
        prediction_d = torch.exp(prediction_d)
        depth = prediction_d.squeeze(-3)

        # BUG FIX: keep everything as torch tensors -- the original
        # converted depth to numpy and then fed it to nn.MSELoss and
        # torch.sqrt (a float), both of which require tensors.
        inv_depth = 1.0 / depth

        error = torch.sqrt(criterion(inv_depth, label) / 2)
        writer1.add_scalar('loss', error.item(), global_step=num)
        writer1.add_images('pre', prediction_d, global_step=num)
        writer1.add_images('label', label, global_step=num)
        # NOTE(review): inv_depth may be 3D after squeeze(-3) while
        # add_images expects a 4D batch -- confirm shapes upstream.
        writer1.add_images('process', inv_depth, global_step=num)
        print("ok")

    # Flush event files once the benchmark pass completes.
    writer1.close()
                tb_logger.add_image("GT heatmaps_{}".format(jj),
                                    hm[0].max(dim=0)[0].unsqueeze(0),
                                    global_step=global_step)
            # add predictions to TB
            tb_logger.add_image("attention maps", att[0],
                                global_step=global_step)

            tb_preds = []
            for jj, d in enumerate(det[0], 1):
                with torch.no_grad():
                    # d = d.clamp(0, 1)
                    d = d.sigmoid()
                    # d = d - d.min()
                    # d /= d.max()
                tb_preds.append(d.unsqueeze(0))
            tb_logger.add_images("detection maps", tb_preds, dataformats="CHW",
                                 global_step=global_step)

            #     tb_logger.add_images("pred_stage_{}".format(jj), tb_preds,
            #                          global_step=global_step,
            #                          dataformats="CHW")
            # add weights and gradients histogram
            for name, param in student.named_parameters():
                tb_logger.add_histogram(name + "_PARAMETERS",
                                        param.cpu().data.numpy(),
                                        global_step)
                if param.grad is not None:
                    tb_logger.add_histogram(name + "_GRADIENTS",
                                            param.grad.cpu().data.numpy(),
                                            global_step)
        # #
        # if global_step % MINIVAL_EVERY_BATCHES == 0:
def main(args):
    """Cluster train-set image features and visualize clusters/neighbors.

    Pipeline: extract CNN features for train and test images, run (or load
    from HDF5) MiniBatchKMeans over the train features, save centroids and
    up to 100 example images per cluster to HDF5, and optionally write
    cluster image grids plus test-image nearest-centroid visualizations to
    tensorboard.
    """
    # Enable cuda by default
    args.cuda = True

    # Define transforms
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    transform = transforms.Compose(
        [transforms.Resize(args.image_size),
         transforms.ToTensor(), normalize])

    # Create datasets
    # NOTE(review): the "val" split is created but never used below.
    datasets = {
        split: RGBDataset(
            os.path.join(args.dataset_root, split),
            seed=123,
            transform=transform,
            image_size=args.image_size,
            truncate_count=args.truncate_count,
        )
        for split in ["train", "val", "test"]
    }

    # Create data loaders
    data_loaders = {
        split: DataLoader(dataset,
                          batch_size=args.batch_size,
                          shuffle=True,
                          num_workers=16)
        for split, dataset in datasets.items()
    }

    device = torch.device("cuda:0" if args.cuda else "cpu")

    # Create model
    net = FeatureNetwork()
    net.to(device)
    net.eval()

    # Generate image features for training images
    train_image_features = []
    train_image_paths = []

    for i, data in enumerate(data_loaders["train"], 0):

        # sample data
        inputs, input_paths = data
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Extract features
        with torch.no_grad():
            feats = net(inputs["rgb"])  # (bs, 512)
        feats = feats.detach().cpu().numpy()
        train_image_features.append(feats)
        train_image_paths += input_paths["rgb"]

    train_image_features = np.concatenate(train_image_features, axis=0)

    # Generate image features for testing images
    test_image_features = []
    test_image_paths = []

    for i, data in enumerate(data_loaders["test"], 0):

        # sample data
        inputs, input_paths = data
        inputs = {key: val.to(device) for key, val in inputs.items()}

        # Extract features
        with torch.no_grad():
            feats = net(inputs["rgb"])  # (bs, 512)
        feats = feats.detach().cpu().numpy()
        test_image_features.append(feats)
        test_image_paths += input_paths["rgb"]

    test_image_features = np.concatenate(test_image_features,
                                         axis=0)  # (N, 512)

    # ================= Perform clustering ==================
    kmeans = MiniBatchKMeans(
        init="k-means++",
        n_clusters=args.num_clusters,
        batch_size=args.batch_size,
        n_init=10,
        max_no_improvement=20,
        verbose=0,
    )
    save_h5_path = os.path.join(args.save_dir,
                                f"clusters_{args.num_clusters:05d}_data.h5")
    if os.path.isfile(save_h5_path):
        # Reuse previously computed centroids; only assignments are redone.
        print("========> Loading existing clusters!")
        h5file = h5py.File(os.path.join(save_h5_path), "r")
        train_cluster_centroids = np.array(h5file["cluster_centroids"])
        kmeans.cluster_centers_ = train_cluster_centroids
        train_cluster_assignments = kmeans.predict(
            train_image_features)  # (N, )
        h5file.close()
    else:
        kmeans.fit(train_image_features)
        train_cluster_assignments = kmeans.predict(
            train_image_features)  # (N, )
        train_cluster_centroids = np.copy(
            kmeans.cluster_centers_)  # (num_clusters, 512)

    # Create a dictionary of cluster -> images for visualization
    cluster2image = {}
    if args.visualize_clusters:
        log_dir = os.path.join(
            args.save_dir, f"train_clusters_#clusters{args.num_clusters:05d}")
        tbwriter = SummaryWriter(log_dir=log_dir)

    for i in range(args.num_clusters):
        valid_idxes = np.where(train_cluster_assignments == i)[0]
        valid_image_paths = [train_image_paths[j] for j in valid_idxes]
        # Shuffle and pick only upto 100 images per cluster
        random.shuffle(valid_image_paths)
        # Read the valid images
        # NOTE(review): a cluster with zero members would make np.stack
        # below raise on an empty list.
        valid_images = []
        for path in valid_image_paths[:100]:
            # cv2 reads BGR; np.flip on the channel axis converts to RGB.
            img = cv2.resize(
                np.flip(cv2.imread(path), axis=2),
                (args.image_size, args.image_size),
            )
            valid_images.append(img)
        valid_images = (np.stack(valid_images, axis=0).astype(np.float32) /
                        255.0)  # (K, H, W, C)
        valid_images = torch.Tensor(valid_images).permute(0, 3, 1,
                                                          2).contiguous()
        cluster2image[i] = valid_images
        if args.visualize_clusters:
            # Write the train image clusters to tensorboard
            tbwriter.add_images(f"Cluster #{i:05d}", valid_images, 0)

    # Persist centroids and per-cluster example images ("a" = append mode,
    # so existing datasets are kept and only missing ones are created).
    h5file = h5py.File(
        os.path.join(args.save_dir,
                     f"clusters_{args.num_clusters:05d}_data.h5"), "a")

    if "cluster_centroids" not in h5file.keys():
        h5file.create_dataset("cluster_centroids",
                              data=train_cluster_centroids)
    for i in range(args.num_clusters):
        if f"cluster_{i}/images" not in h5file.keys():
            h5file.create_dataset(f"cluster_{i}/images", data=cluster2image[i])

    h5file.close()

    if args.visualize_clusters:
        # Dot product of test_image_features with train_cluster_centroids
        test_dot_centroids = np.matmul(
            test_image_features,
            train_cluster_centroids.T)  # (N, num_clusters)
        # Map similarity scores into [0, 1] (affine rescale assumes
        # normalized embeddings; otherwise softmax over clusters).
        if args.normalize_embedding:
            test_dot_centroids = (test_dot_centroids + 1.0) / 2.0
        else:
            test_dot_centroids = F.softmax(torch.Tensor(test_dot_centroids),
                                           dim=1).numpy()

        # Find the top-K matching centroids
        topk_matches = np.argpartition(test_dot_centroids, -5,
                                       axis=1)[:, -5:]  # (N, 5)

        # Write the test nearest neighbors to tensorboard
        tbwriter = SummaryWriter(log_dir=os.path.join(
            args.save_dir, f"test_neighbors_#clusters{args.num_clusters:05d}"))
        for i in range(100):
            test_image_path = test_image_paths[i]
            test_image = cv2.resize(cv2.imread(test_image_path),
                                    (args.image_size, args.image_size))
            test_image = np.flip(test_image, axis=2).astype(np.float32) / 255.0
            test_image = torch.Tensor(test_image).permute(2, 0, 1).contiguous()
            topk_clusters = topk_matches[i]
            # Pick some 4 images representative of a cluster
            topk_cluster_images = []
            for k in topk_clusters:
                imgs = cluster2image[k][:4]  # (4, C, H, W)
                if imgs.shape[0] == 0:
                    continue
                elif imgs.shape[0] != 4:
                    # Pad with black images up to 4 so the 2x2 grid works.
                    imgs_pad = torch.zeros(4 - imgs.shape[0], *imgs.shape[1:])
                    imgs = torch.cat([imgs, imgs_pad], dim=0)
                # Downsample by a factor of 2
                imgs = F.interpolate(imgs, scale_factor=0.5,
                                     mode="bilinear")  # (4, C, H/2, W/2)
                # Reshape to form a grid
                imgs = imgs.permute(1, 0, 2, 3)  # (C, 4, H/2, W/2)
                C, _, Hby2, Wby2 = imgs.shape
                imgs = (imgs.view(C, 2, 2, Hby2, Wby2).permute(
                    0, 1, 3, 2, 4).contiguous().view(C, Hby2 * 2, Wby2 * 2))
                # Draw a red border
                imgs[0, :4, :] = 1.0
                imgs[1, :4, :] = 0.0
                imgs[2, :4, :] = 0.0
                imgs[0, -4:, :] = 1.0
                imgs[1, -4:, :] = 0.0
                imgs[2, -4:, :] = 0.0
                imgs[0, :, :4] = 1.0
                imgs[1, :, :4] = 0.0
                imgs[2, :, :4] = 0.0
                imgs[0, :, -4:] = 1.0
                imgs[1, :, -4:] = 0.0
                imgs[2, :, -4:] = 0.0
                topk_cluster_images.append(imgs)

            # Concatenate test image and cluster grids side by side; encode
            # the per-cluster scores into the image tag.
            vis_img = torch.cat([test_image, *topk_cluster_images], dim=2)
            image_name = f"Test image #{i:04d}"
            for k in topk_clusters:
                score = test_dot_centroids[i, k].item()
                image_name += f"_{score:.3f}"
            tbwriter.add_image(image_name, vis_img, 0)
class Trainer:
    """Training/evaluation driver for a VAE on (3, 32, 32) images.

    Expects `args` to provide: num_epochs, cuda, test_every, learning_rate,
    learning_rate_decay, weight_decay, summary_dir, checkpoint_dir, resume
    and resume_from. Logs scalars/images to TensorBoard and checkpoints
    after every epoch.
    """

    def __init__(self, model, loss, train_loader, test_loader, args):
        self.model = model
        self.args = args
        self.args.start_epoch = 0

        self.train_loader = train_loader
        self.test_loader = test_loader

        # Loss function and optimizer.
        self.loss = loss
        self.optimizer = self.get_optimizer()

        # TensorBoard writer.
        self.summary_writer = SummaryWriter(log_dir=args.summary_dir)
        # Optionally restore model/optimizer/start_epoch from a checkpoint.
        if args.resume:
            self.load_checkpoint(self.args.resume_from)

    def train(self):
        """Train for `args.num_epochs` epochs; evaluates every `args.test_every`."""
        self.model.train()
        for epoch in range(self.args.start_epoch, self.args.num_epochs):
            loss_list = []
            print("epoch {}...".format(epoch))
            for batch_idx, (data, _) in enumerate(tqdm(self.train_loader)):
                if self.args.cuda:
                    data = data.cuda()
                # NOTE: the legacy Variable(data) wrap was removed — it has
                # been a no-op since PyTorch 0.4.
                self.optimizer.zero_grad()
                recon_batch, mu, logvar = self.model(data)
                loss = self.loss(recon_batch, data, mu, logvar)
                loss.backward()
                self.optimizer.step()
                loss_list.append(loss.item())

            print("epoch {}: - loss: {}".format(epoch, np.mean(loss_list)))
            new_lr = self.adjust_learning_rate(epoch)
            print('learning rate:', new_lr)

            self.summary_writer.add_scalar('training/loss', np.mean(loss_list),
                                           epoch)
            self.summary_writer.add_scalar('training/learning_rate', new_lr,
                                           epoch)
            self.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            })
            if epoch % self.args.test_every == 0:
                self.test(epoch)

    def test(self, cur_epoch):
        """Evaluate on the test set; log mean loss and a reconstruction grid."""
        print('testing...')
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            for i, (data, _) in enumerate(self.test_loader):
                if self.args.cuda:
                    data = data.cuda()
                recon_batch, mu, logvar = self.model(data)
                test_loss += self.loss(recon_batch, data, mu, logvar).item()
                # argmax over dim 1 — presumably the decoder emits per-pixel
                # logits over 256 intensity levels (TODO confirm with model).
                _, indices = recon_batch.max(1)
                indices = indices.float() / 255
                if i == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat(
                        [data[:n], indices.view(-1, 3, 32, 32)[:n]])
                    self.summary_writer.add_images('testing_set/image',
                                                   comparison, cur_epoch)
                    comparison = torchvision.utils.make_grid(comparison,
                                                             nrow=6)
                    # Ensure the output directory exists before saving.
                    os.makedirs('results', exist_ok=True)
                    torchvision.utils.save_image(comparison.cpu(),
                                                 'results/reconstruction_' +
                                                 str(cur_epoch) + '.png',
                                                 nrow=8)

        test_loss /= len(self.test_loader.dataset)
        print('====> Test set loss: {:.4f}'.format(test_loss))
        self.summary_writer.add_scalar('testing/loss', test_loss, cur_epoch)
        self.model.train()

    def test_on_trainings_set(self):
        """Diagnostic: evaluate on the *training* set and log sample grids."""
        print('testing...')
        with torch.no_grad():
            self.model.eval()
            test_loss = 0
            for i, (data, _) in enumerate(self.train_loader):
                if self.args.cuda:
                    data = data.cuda()
                recon_batch, mu, logvar = self.model(data)
                test_loss += self.loss(recon_batch, data, mu, logvar).item()
                _, indices = recon_batch.max(1)
                indices = indices.float() / 255
                if i % 50 == 0:
                    n = min(data.size(0), 8)
                    comparison = torch.cat(
                        [data[:n], indices.view(-1, 3, 32, 32)[:n]])
                    self.summary_writer.add_images('training_set/image',
                                                   comparison, i)

        # BUG FIX: normalize by the training-set size (the original divided
        # by len(self.test_loader.dataset) while iterating the train loader).
        test_loss /= len(self.train_loader.dataset)
        print('====> Test on training set loss: {:.4f}'.format(test_loss))
        self.model.train()

    def get_optimizer(self):
        """Build an Adam optimizer with lr/weight-decay taken from args."""
        return optim.Adam(self.model.parameters(),
                          lr=self.args.learning_rate,
                          weight_decay=self.args.weight_decay)

    def adjust_learning_rate(self, epoch):
        """Exponentially decay: lr = base_lr * decay**epoch. Returns the new lr."""
        learning_rate = self.args.learning_rate * (
            self.args.learning_rate_decay**epoch)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = learning_rate
        return learning_rate

    def save_checkpoint(self,
                        state,
                        is_best=False,
                        filename='checkpoint.pth.tar'):
        '''
        Save a training checkpoint; optionally copy it aside as the best model.
        :param state: {'epoch': cur_epoch + 1, 'state_dict': self.model.state_dict(),
                            'optimizer': self.optimizer.state_dict()}
        :param is_best: boolean to save the checkpoint aside if it has the best score so far
        :param filename: the name of the saved file
        '''
        # os.path.join is robust to a missing trailing separator in
        # checkpoint_dir (plain string concatenation was not).
        path = os.path.join(self.args.checkpoint_dir, filename)
        torch.save(state, path)
        if is_best:
            shutil.copyfile(path,
                            os.path.join(self.args.checkpoint_dir,
                                         'model_best.pth.tar'))

    def load_checkpoint(self, filename):
        """Restore model/optimizer/start_epoch from a checkpoint, if present."""
        filename = os.path.join(self.args.checkpoint_dir, filename)
        try:
            print("Loading checkpoint '{}'".format(filename))
            checkpoint = torch.load(filename)
            self.args.start_epoch = checkpoint['epoch']
            self.model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            print("Checkpoint loaded successfully from '{}' at (epoch {})\n".
                  format(self.args.checkpoint_dir, checkpoint['epoch']))
        except (OSError, KeyError, RuntimeError):
            # Narrowed from a bare `except:`, which also swallowed
            # KeyboardInterrupt and programming errors.
            print("No checkpoint exists from '{}'. Skipping...\n".format(
                self.args.checkpoint_dir))
Exemplo n.º 18
0
def train_pru_mod(epoch=100, batch=4, lr=0.001):
    """Fine-tune the pruned depth network on NYU data.

    :param epoch: number of epochs to train.
    :param batch: batch size passed to the NYU data loader.
    :param lr: Adam learning rate.

    Side effects: writes scalars/images to ./gj_dir/train_pru_mod and saves
    the unwrapped (non-DataParallel) model to ./gj_dir/after_nyu.pth.tar
    after every epoch.
    """
    import time  # hoisted: was imported in the middle of the function

    net = load_pru_mod(after_finetune=True).double()
    net = nn.DataParallel(net)
    net = net.cuda()

    train_Data = nyu_set.use_nyu_data(batch_s=batch,
                                      max_len=160,
                                      isBenchmark=False)
    writer1 = SummaryWriter('./gj_dir/train_pru_mod')

    criterion = nn.MSELoss(reduction='mean').cuda()
    Joint = JointLoss(opt=None).double().cuda()
    s_loss = ts_loss.SSIM().cuda()
    optimizer = optim.Adam(net.parameters(), lr=lr)

    net.train()
    batch_size = batch
    # BUG FIX: the original used (epoch + 1) * batch_size + i as the
    # TensorBoard step, which collides/overlaps across epochs; keep a
    # monotonically increasing counter instead.
    global_step = 0
    for ep in range(epoch):  # `ep` avoids shadowing the `epoch` parameter
        time_start = time.time()

        for i, data in enumerate(train_Data):
            images, depths = data
            # Legacy Variable() wrapping removed (no-op since PyTorch 0.4).
            images = images.double().cuda()
            depths = depths.double().cuda()

            optimizer.zero_grad()

            output_net = net(images)[0].double()

            # Depth regression plus image-aware smoothness terms and SSIM.
            loss1 = criterion(output_net, depths)
            loss2 = Joint.LaplacianSmoothnessLoss(output_net, images)
            loss3 = Joint.compute_image_aware_2nd_smoothness_cost(
                output_net, images)
            loss4 = 1 - s_loss.forward(output_net, depths)
            loss = loss1 * 10 + loss2 + loss3 + loss4

            loss.backward()
            optimizer.step()

            print('[%d, %5d] loss: %.4f  A:%.4f  B:%.4f C:%.4f D:%.4f' %
                  (ep + 1, (i + 1) * batch_size, loss.item(), loss1.item(),
                   loss2.item(), loss3.item(), loss4.item()))

            writer1.add_scalar('loss', loss.item(), global_step=global_step)
            writer1.add_scalar('loss2', loss2.item(), global_step=global_step)
            global_step += 1

        # Log the last batch's prediction and labels once per epoch.
        writer1.add_images('pre', output_net, global_step=ep)

        # exp() then invert — presumably the net predicts log-depth and this
        # produces an inverse-depth visualization (TODO confirm).
        dep = torch.exp(output_net)
        dep = dep.detach().cpu().numpy()
        inv_dep = 1.0 / dep * 255

        writer1.add_images('pro-dep', inv_dep, global_step=ep)

        writer1.add_images('labels', depths, global_step=ep)

        # Save the unwrapped module so it loads without DataParallel.
        torch.save(net.module, "./gj_dir/after_nyu.pth.tar")
        time_end = time.time()
        print('Time cost:', time_end - time_start, "s")

    print('Finished Training')
class PredictionCallback(tf.keras.callbacks.Callback):
    """Keras callback that logs segmentation predictions with tensorboardX.

    Every `update_freq` epochs it pulls one batch from `val_generator`,
    runs the model on it and writes inputs, targets and predictions
    (plus colored overlays) as NHWC image summaries.
    """

    def __init__(self, model, logdir, val_generator, scaled_mask,
                 binary_threshold=0.5, update_freq=1):
        """
        :param model: compiled Keras model used for prediction.
        :param logdir: tensorboardX log directory.
        :param val_generator: iterable yielding (inputs, targets) batches.
        :param scaled_mask: when True, targets are scalar masks that get a
            channel axis added and predictions are hard-thresholded; when
            False, targets/predictions are treated as one-hot and
            argmax-reduced.
        :param binary_threshold: threshold used to binarize predictions.
        :param update_freq: log every `update_freq` epochs.
        """
        super(PredictionCallback, self).__init__()
        self.val_generator = val_generator
        self.writer = SummaryWriter(logdir=logdir)
        self.scaled_mask = scaled_mask
        self.binary_threshold = binary_threshold
        self._model = model
        # Number of classes inferred from the model's last output dimension.
        self.num_classes = self._model.output.shape.as_list()[-1]
        self.update_freq = update_freq

    def on_epoch_end(self, epoch, logs=None):
        """Write image summaries for one validation batch (1-based epoch)."""
        # BUG FIX: `logs={}` was a mutable default argument; replaced with
        # None. `logs` is unused here but kept for the Keras signature.

        if epoch == -1:
            epoch = 0
        else:
            epoch = epoch + 1

        if epoch % self.update_freq == 0:
            logger.debug("logging images to tensorboard, epoch=%d" % epoch)

            # Grab exactly one batch from the (possibly infinite) generator.
            for input_batch, target_batch in self.val_generator:
                input_batch = input_batch.numpy()
                target_batch = target_batch.numpy()
                break

            if self.scaled_mask:
                # Add the channel axis expected by the mask utilities.
                target_batch = np.expand_dims(target_batch, axis=-1)

            # predict
            pred_batch = self._model.predict_on_batch(input_batch).numpy()

            predictions_on_inputs = masks.get_colored_segmentation_mask(pred_batch, self.num_classes, images=input_batch, binary_threshold=self.binary_threshold)
            self.writer.add_images('inputs/with_predictions', predictions_on_inputs, dataformats='NHWC', global_step=epoch)

            targets_on_inputs = masks.get_colored_segmentation_mask(target_batch, self.num_classes, images=input_batch, binary_threshold=self.binary_threshold)
            self.writer.add_images('inputs/with_targets', targets_on_inputs, dataformats='NHWC', global_step=epoch)

            targets_rgb = masks.get_colored_segmentation_mask(target_batch, self.num_classes, binary_threshold=self.binary_threshold, alpha=1.0)
            self.writer.add_images('targets/rgb', targets_rgb, dataformats='NHWC', global_step=epoch)

            pred_rgb = masks.get_colored_segmentation_mask(pred_batch, self.num_classes, binary_threshold=self.binary_threshold, alpha=1.0)
            self.writer.add_images('predictions/rgb', pred_rgb, dataformats='NHWC', global_step=epoch)

            if not self.scaled_mask:
                # Multi-class: collapse one-hot masks to class indices.
                pred_batch = np.argmax(pred_batch, axis=-1).astype(np.float32)
                target_batch = np.argmax(target_batch, axis=-1).astype(np.float32)

                # reshape so add_images gets an explicit channel axis
                pred_batch = np.expand_dims(pred_batch, axis=-1)
                target_batch = np.expand_dims(target_batch, axis=-1)
            else:
                # Binary: hard-threshold the predicted probabilities.
                pred_batch[pred_batch > self.binary_threshold] = 1.0
                pred_batch[pred_batch <= self.binary_threshold] = 0.0

            self.writer.add_images('inputs', input_batch, dataformats='NHWC', global_step=epoch)
            self.writer.add_images('targets', target_batch, dataformats='NHWC', global_step=epoch)
            self.writer.add_images('predictions', pred_batch, dataformats='NHWC', global_step=epoch)
Exemplo n.º 20
0
                # NOTE(review): orphaned fragment of a larger VAE+classifier
                # training loop — `Lambda`, `reconstruction_loss`, `lstd_loss`,
                # `classifier_loss`, `phase`, `idx`, `writer`, `optim`, etc.
                # are defined in the enclosing function (not visible here).
                # Weighted sum of each loss into a total loss function
                reconstruction_loss = (1 - Lambda) * reconstruction_loss
                lstd_loss = Lambda * lstd_loss
                classifier_loss = (1 - Lambda) * classifier_loss

                total_loss = reconstruction_loss + lstd_loss + classifier_loss

                # Backprop Total loss (VAE + Classifer) and update VAE if training
                if phase == 'train':
                    # Calculate Total loss and Backprop through network
                    total_loss.backward()
                    optim.step()

                # Periodically log input/reconstruction image grids.
                if idx % 1000 == 0:
                    writer.add_images("Images", inputs, global_step=epoch)
                    writer.add_images("Reconstructions",
                                      recon_images,
                                      global_step=epoch)
                    # writer.add_graph(VAE(), (inputs, labels), global_step=epoch)

            # NOTE(review): this prints the running totals *before* the current
            # batch is accumulated below — confirm the intended ordering.
            if idx % 100 == 0 and args.display:
                print(
                    f'{phase} Batch Loss: {running_cls_loss}| Acc: {running_corrects / batch_size}'
                )

            # (Classifer) Average Loss and Accuracy for current batch
            running_cls_loss += classifier_loss.item() * inputs.size(0)
            running_corrects += torch.sum(predicted == labels.data)

            # (VAE) Average Losses for current batch
Exemplo n.º 21
0
def train(
        backbone,
        root_dir,
        train_index_fp,
        pretrain_model,
        optimizer,
        epochs=50,
        lr=0.001,
        wd=5e-4,
        momentum=0.9,
        batch_size=4,
        ctx=mx.cpu(),
        verbose_step=5,
        output_dir='ckpt',
):
    """Train a PSENet text detector (MXNet/Gluon) and log to TensorBoard.

    :param backbone: base-net name; also used as the checkpoint sub-directory.
    :param root_dir: dataset root directory.
    :param train_index_fp: path to the training index file.
    :param pretrain_model: optional parameter file to warm-start from, or None.
    :param optimizer: optimizer name, e.g. 'sgd' or 'adam'.
    :param epochs: number of epochs.
    :param lr: base lr; stepped down x0.1 at 1/3 and 2/3 of total updates.
    :param wd: weight decay.
    :param momentum: momentum (removed from params when optimizer is 'adam').
    :param batch_size: per-device batch size; multiplied by len(ctx) below.
    :param ctx: mx context or list of contexts to train on.
        NOTE(review): mx.cpu() as a default is evaluated once at import time;
        harmless for a context object, but worth confirming.
    :param verbose_step: log scalars every `verbose_step` batches.
    :param output_dir: where checkpoints and TensorBoard events are written.
    """
    output_dir = os.path.join(output_dir, backbone)
    os.makedirs(output_dir, exist_ok=True)
    # num_kernels - 1 shrunken kernel maps plus the complete text map.
    num_kernels = 3
    dataset = StdDataset(root_dir=root_dir,
                         train_idx_fp=train_index_fp,
                         num_kernels=num_kernels - 1)
    if not isinstance(ctx, (list, tuple)):
        ctx = [ctx]
    batch_size = batch_size * len(ctx)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    net = PSENet(base_net_name=backbone,
                 num_kernels=num_kernels,
                 ctx=ctx,
                 pretrained=True)
    # initial params
    net.initialize(mx.init.Xavier(), ctx=ctx)
    # Force re-init of the non-pretrained extra/decoder layers.
    net.collect_params("extra_.*_weight|decoder_.*_weight").initialize(
        mx.init.Xavier(), ctx=ctx, force_reinit=True)
    net.collect_params("extra_.*_bias|decoder_.*_bias").initialize(
        mx.init.Zero(), ctx=ctx, force_reinit=True)

    if pretrain_model is not None:
        net.load_parameters(pretrain_model,
                            ctx=ctx,
                            allow_missing=True,
                            ignore_extra=True)

    # pse_loss = DiceLoss(lam=0.7, num_kernels=num_kernels)
    pse_loss = DiceLoss_with_OHEM(lam=0.7,
                                  num_kernels=num_kernels,
                                  debug=False)

    # lr_scheduler = ls.PolyScheduler(
    #     max_update=icdar_loader.length * epochs // batch_size, base_lr=lr
    # )
    max_update = len(dataset) * epochs // batch_size
    lr_scheduler = ls.MultiFactorScheduler(
        base_lr=lr, step=[max_update // 3, max_update * 2 // 3], factor=0.1)

    optimizer_params = {
        'learning_rate': lr,
        'wd': wd,
        'momentum': momentum,
        'lr_scheduler': lr_scheduler,
    }
    if optimizer.lower() == 'adam':
        # Adam takes no momentum hyper-parameter.
        optimizer_params.pop('momentum')

    trainer = Trainer(net.collect_params(),
                      optimizer=optimizer,
                      optimizer_params=optimizer_params)
    summary_writer = SummaryWriter(output_dir)
    for e in range(epochs):
        cumulative_loss = 0

        num_batches = 0
        for i, item in enumerate(loader):
            # Scatter each field of the batch across the training contexts.
            item_ctxs = [split_and_load(field, ctx) for field in item]
            loss_list = []
            for im, gt_text, gt_kernels, training_masks, ori_img in zip(
                    *item_ctxs):
                # Targets are subsampled to 1/4 of the input resolution.
                gt_text = gt_text[:, ::4, ::4]
                gt_kernels = gt_kernels[:, :, ::4, ::4]
                training_masks = training_masks[:, ::4, ::4]

                with autograd.record():
                    kernels_pred = net(im)  # index 0 is the prediction for the complete text
                    loss = pse_loss(gt_text, gt_kernels, kernels_pred,
                                    training_masks)
                    loss_list.append(loss)
            mean_loss = []
            for loss in loss_list:
                loss.backward()
                mean_loss.append(mx.nd.mean(to_cpu(loss)).asscalar())
            mean_loss = np.mean(mean_loss)
            trainer.step(batch_size)

            if i % verbose_step == 0:
                # NOTE(review): dataset.length here vs len(dataset) above for
                # max_update — confirm the two agree.
                global_steps = dataset.length * e + i * batch_size
                summary_writer.add_scalar('loss', mean_loss, global_steps)
                summary_writer.add_scalar(
                    'c_loss',
                    mx.nd.mean(to_cpu(pse_loss.C_loss)).asscalar(),
                    global_steps,
                )
                summary_writer.add_scalar(
                    'kernel_loss',
                    mx.nd.mean(to_cpu(pse_loss.kernel_loss)).asscalar(),
                    global_steps,
                )
                summary_writer.add_scalar('pixel_accuracy', pse_loss.pixel_acc,
                                          global_steps)
            if i % 1 == 0:
                logger.info(
                    "step: {}, lr: {}, "
                    "loss: {}, score_loss: {}, kernel_loss: {}, pixel_acc: {}, kernel_acc: {}"
                    .format(
                        i * batch_size,
                        trainer.learning_rate,
                        mean_loss,
                        mx.nd.mean(to_cpu(pse_loss.C_loss)).asscalar(),
                        mx.nd.mean(to_cpu(pse_loss.kernel_loss)).asscalar(),
                        pse_loss.pixel_acc,
                        pse_loss.kernel_acc,
                    ))
            cumulative_loss += mean_loss
            num_batches += 1
        # NOTE(review): `global_steps` (and below `gt_text`/`kernels_pred`)
        # are loop-carried from the last batch; an empty loader would raise
        # NameError here, and num_batches == 0 would divide by zero.
        summary_writer.add_scalar('mean_loss_per_epoch',
                                  cumulative_loss / num_batches, global_steps)
        logger.info("Epoch {}, mean loss: {}\n".format(
            e, cumulative_loss / num_batches))
        net.save_parameters(
            os.path.join(output_dir, model_fn_prefix(backbone, e)))

    # Final visual summary of the last processed batch.
    summary_writer.add_image('complete_gt', to_cpu(gt_text[0:1, :, :]),
                             global_steps)
    summary_writer.add_image('complete_pred',
                             to_cpu(kernels_pred[0:1, 0, :, :]), global_steps)
    # reshape(-1, 1, 0, 0): in MXNet reshape, 0 keeps the input dimension.
    summary_writer.add_images(
        'kernels_gt',
        to_cpu(gt_kernels[0:1, :, :, :]).reshape(-1, 1, 0, 0),
        global_steps,
    )
    summary_writer.add_images(
        'kernels_pred',
        to_cpu(kernels_pred[0:1, 1:, :, :]).reshape(-1, 1, 0, 0),
        global_steps,
    )

    summary_writer.close()
Exemplo n.º 22
0
class Train:
    """Training driver for TANet / DR-TANet change detection.

    Attributes such as self.args, self.dataset_train_loader, self.dataset_test
    and self.checkpoint_save are expected to be supplied elsewhere (e.g. by a
    subclass or external assignment) before train() is called; run() builds
    self.model and starts training.
    """

    def __init__(self):
        # Epoch counter and global optimization-step counter.
        self.epoch = 0
        self.step = 0

    def train(self):
        """Run the full training loop with linear LR decay and periodic tests."""
        weight = torch.ones(2)
        criterion = criterion_CEloss(weight.cuda())
        optimizer = torch.optim.Adam(self.model.parameters(), lr=0.001,
                                     betas=(0.9, 0.999))
        # Linearly decay the LR towards zero over the whole run.
        lambda_lr = lambda epoch: float(
            self.args.max_epochs * len(self.dataset_train_loader) - self.step
        ) / float(self.args.max_epochs * len(self.dataset_train_loader))
        model_lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda_lr)

        f_loss = open(pjoin(self.checkpoint_save, "loss.csv"), 'w')
        loss_writer = csv.writer(f_loss)

        self.visual_writer = SummaryWriter(
            os.path.join(self.checkpoint_save, 'logs'))

        loss_item = []

        max_step = self.args.max_epochs * len(self.dataset_train_loader)
        _, w, h = self.dataset_test.get_random_image()[0].shape
        # One visualization slot per scheduled test; guard against
        # step_test <= 0 (the original division crashed in that case).
        n_tests = max_step // self.args.step_test if self.args.step_test > 0 else 0
        img_tbx = np.zeros((n_tests, 3, w * 2, h * 2), dtype=np.uint8)

        while self.epoch < self.args.max_epochs:

            for step, (inputs_train, mask_train) in enumerate(
                    tqdm(self.dataset_train_loader)):
                self.model.train()
                inputs_train = inputs_train.cuda()
                mask_train = mask_train.cuda()
                output_train = self.model(inputs_train)
                optimizer.zero_grad()
                self.loss = criterion(output_train, mask_train[:, 0])
                # BUG FIX: append the Python float, not the loss tensor —
                # storing tensors kept every batch's autograd graph alive
                # for the whole epoch (memory leak).
                loss_item.append(self.loss.item())
                self.loss.backward()
                optimizer.step()
                self.step += 1
                loss_writer.writerow([self.step, self.loss.item()])
                self.visual_writer.add_scalar('loss', self.loss.item(),
                                              self.step)

                if self.args.step_test > 0 and self.step % self.args.step_test == 0:
                    print('testing...')
                    self.model.eval()
                    self.test(img_tbx)

            print('Loss for Epoch {}:{:.03f}'.format(
                self.epoch, sum(loss_item) / len(self.dataset_train_loader)))
            loss_item.clear()
            model_lr_scheduler.step()
            self.epoch += 1
            if self.args.epoch_save > 0 and self.epoch % self.args.epoch_save == 0:
                self.checkpoint()

        self.visual_writer.add_images('cd_test', img_tbx, 0, dataformats='NCHW')
        f_loss.close()
        self.visual_writer.close()

    def test(self, img_tbx):
        """Predict a change mask on one random test pair and store the visual."""
        _, _, w_r, h_r = img_tbx.shape
        w_r //= 2
        h_r //= 2
        inputs, mask_gt = self.dataset_test.get_random_image()

        # `inputs` holds both timestamps stacked along the channel axis.
        inputs = inputs.view(1, -1, h_r, w_r)
        inputs = inputs.cuda()
        output = self.model(inputs)

        inputs = inputs[0].cpu().data
        img_t0 = inputs[0:3, :, :]
        img_t1 = inputs[3:6, :, :]
        # Undo the [-1, 1] normalization back to the uint8 range.
        img_t0 = (img_t0 + 1) * 128
        img_t1 = (img_t1 + 1) * 128
        output = output[0].cpu().data
        # Channel 0 softmax prob > 0.5 -> "no change" (0); else change (255).
        mask_pred = np.where(F.softmax(output[0:2, :, :], dim=0)[0] > 0.5, 0, 255)
        mask_gt = np.squeeze(np.where(mask_gt == True, 255, 0), axis=0)
        self.store_result(img_t0, img_t1, mask_gt, mask_pred, img_tbx)

    def store_result(self, t0, t1, mask_gt, mask_pred, img_save):
        """Pack t0/t1 images and GT/predicted masks into one 2x2 grid slot."""
        _, _, w, h = img_save.shape
        w //= 2
        h //= 2
        # Slot index: one entry per completed test interval.
        i = self.step // self.args.step_test - 1
        img_save[i, :, 0:w, 0:h] = t0.numpy().astype(np.uint8)
        img_save[i, :, 0:w, h:2 * h] = t1.numpy().astype(np.uint8)
        img_save[i, :, w:2 * w, 0:h] = np.transpose(cv2.cvtColor(mask_gt.astype(np.uint8), cv2.COLOR_GRAY2RGB), (2, 0, 1)).astype(np.uint8)
        img_save[i, :, w:2 * w, h:2 * h] = np.transpose(cv2.cvtColor(mask_pred.astype(np.uint8), cv2.COLOR_GRAY2RGB), (2, 0, 1)).astype(np.uint8)

    def checkpoint(self):
        """Save the model's state_dict under checkpointdir/<step>.pth."""
        filename = '{:08d}.pth'.format(self.step)
        cp_path = pjoin(self.checkpoint_save, 'checkpointdir')
        if not os.path.exists(cp_path):
            os.makedirs(cp_path)
        torch.save(self.model.state_dict(), pjoin(cp_path, filename))
        print("Net Parameters in step:{:08d} were saved.".format(self.step))

    def run(self):
        """Build the TANet/DR-TANet model per args and start training."""
        self.model = TANet(self.args.encoder_arch, self.args.local_kernel_size, self.args.attn_stride,
                           self.args.attn_padding, self.args.attn_groups, self.args.drtam, self.args.refinement)

        if self.args.drtam:
            print('Dynamic Receptive Temporal Attention Network (DR-TANet)')
        else:
            print('Temporal Attention Network (TANet)')

        print('Encoder:' + self.args.encoder_arch)
        if self.args.refinement:
            print('Adding refinement...')

        if self.args.multi_gpu:
            self.model = nn.DataParallel(self.model).cuda()
        else:
            self.model = self.model.cuda()
        self.train()
    def fit(self,
            train_loader_fn: Callable,
            epochs: int = 2,
            lr: float = 1e-3,
            n_critic: int = 5,
            disk_backup_filename: str = "dumped_weights.bin"):
        """
        Trains this WGAN.
        :param train_loader_fn: loader-returning function to generate training data.
        :param epochs: number of epochs. An epoch is a full pass over the train loader.
        :param lr: learning rate.
        :param n_critic: number of steps to train the discriminator (a.k.a. critic) per each training
            step of the generator. In WGANs, it is OK to make it large.
        :param disk_backup_filename: filename to dump trainable parameters. Dumping is done once per epoch.
            NOTE(review): in the code below the disk dump only happens on
            first-time initialization, not per epoch — confirm the docstring.
        """
        # Ensure both networks' parameters are trainable.
        Util.set_param_requires_grad(self.generator, True)
        Util.set_param_requires_grad(self.discriminator, True)
        if not self.params:
            # First run: random-init and persist the starting weights.
            self.random_init()
            self.save_params()
            self.save_params_to_disk(disk_backup_filename)
        else:
            self.restore_params()

        # RMSprop optimizers for generator and critic.
        g_optimizer = torch.optim.RMSprop(self.generator.parameters(), lr=lr)
        d_optimizer = torch.optim.RMSprop(self.discriminator.parameters(),
                                          lr=lr)
        self.generator.train()
        self.discriminator.train()
        writer = SummaryWriter("gan_training")
        batch_index = 0

        # NOTE(review): no critic weight clipping / gradient penalty is
        # visible here; the Lipschitz constraint is presumably enforced
        # elsewhere (e.g. inside the discriminator) — confirm.
        for epoch in range(epochs):
            data_sampler = iter(train_loader_fn())
            while True:
                # preload real batches
                real_batches = []
                for i in range(n_critic):
                    try:
                        real_data, _ = next(data_sampler)
                        real_batches.append(real_data)
                    except StopIteration:
                        # for simplicity, omitting the last incomplete sequence of batches
                        break
                if len(real_batches) != n_critic:
                    # next epoch
                    break

                batch_size = real_batches[0].shape[0]

                # train
                # NOTE(review): zero_grad() is called once before the n_critic
                # inner steps, so critic gradients accumulate between step()
                # calls — confirm this is intended.
                d_optimizer.zero_grad()
                for i in range(n_critic):
                    real_data = real_batches[i]
                    fake_data = self.generator(
                        Util.conditional_to_cuda(
                            torch.randn(batch_size, self.latent_dim)))
                    # critic loss = E[D(fake)] - E[D(real)]; minimizing it
                    # pushes D(real) up and D(fake) down.
                    loss1 = self.discriminator(real_data).mean()
                    loss2 = self.discriminator(fake_data).mean()
                    discriminator_loss = -(loss1 - loss2)
                    discriminator_loss.backward()
                    d_optimizer.step()

                # generator loss = -E[D(fake)]: fool the critic.
                g_optimizer.zero_grad()
                fake_data = self.generator(
                    Util.conditional_to_cuda(
                        torch.randn(batch_size, self.latent_dim)))
                generator_loss = -self.discriminator(fake_data).mean()
                #generator_loss = (fake_data - real_data).abs().mean()
                generator_loss.backward()
                g_optimizer.step()

                # eval
                with torch.no_grad():
                    writer.add_scalar("discriminator_loss",
                                      discriminator_loss.detach(), batch_index)
                    writer.add_scalar("generator_loss",
                                      generator_loss.detach(), batch_index)
                    writer.add_scalar("epoch", epoch, batch_index)
                    if batch_index % 20 == 0:
                        # clamp to [-1, 1] for valid image logging.
                        writer.add_images("generated_batch",
                                          fake_data.detach().clamp(-1, 1),
                                          batch_index)
                        writer.add_images("real_batch", real_data.detach(),
                                          batch_index)
                batch_index += 1
Exemplo n.º 24
0
def fit(model,
        device,
        data_path,
        epochs=5,
        batch_size=2,
        lr=0.001,
        val_percent=0.1):
    """Train a segmentation model on the iris dataset.

    :param model: network exposing an `n_classes` attribute; trained in place.
    :param device: torch.device to run inference on for the mask previews.
    :param data_path: argument tuple unpacked into IrisDataset(...).
    :param epochs: number of passes over the training split.
    :param batch_size: mini-batch size for both loaders.
    :param lr: RMSprop learning rate.
    :param val_percent: fraction of the dataset held out for validation.
    """
    # Split image/mask data into train and validation subsets.
    dataset = IrisDataset(*data_path)
    n_valid = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_valid
    train_ds, valid_ds = random_split(dataset, [n_train, n_valid])
    train_dl = DataLoader(train_ds,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=4,
                          pin_memory=True)
    valid_dl = DataLoader(valid_ds,
                          batch_size=batch_size,
                          shuffle=False,
                          num_workers=4,
                          pin_memory=True,
                          drop_last=True)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_valid}
        Device:          {device.type}
    ''')

    writer = SummaryWriter()
    global_step = 0

    optimizer = optim.RMSprop(model.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    loss_func = nn.CrossEntropyLoss(
    ) if model.n_classes > 1 else nn.BCEWithLogitsLoss()

    # BUG FIX: guard against a zero interval when the dataset is smaller than
    # 10 * batch_size (the original modulo raised ZeroDivisionError).
    eval_interval = max(1, len(dataset) // (10 * batch_size))

    for epoch in range(epochs):
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='image') as pbar:
            model.train()
            for image, mask in train_dl:
                loss, _ = loss_batch(model, device, loss_func, image, mask,
                                     optimizer)

                writer.add_scalar('Loss/train', loss, global_step)
                pbar.set_postfix(**{'loss (batch)': loss})
                pbar.update(image.shape[0])

                global_step += 1
                if global_step % eval_interval == 0:
                    for tag, value in model.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        # BUG FIX: the original logged value.data (the
                        # weights) under 'grads/'; log the gradients.
                        if value.grad is not None:
                            writer.add_histogram('grads/' + tag,
                                                 value.grad.data.cpu().numpy(),
                                                 global_step)

                    # Validation pass; the scheduler steps on validation loss.
                    model.eval()
                    with torch.no_grad():
                        losses, nums = zip(*[
                            loss_batch(model, device, loss_func, image, mask)
                            for image, mask in valid_dl
                        ])
                    val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
                    scheduler.step(val_loss)

                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)
                    if model.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_loss))
                        writer.add_scalar('Loss/test', val_loss, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_loss))
                        writer.add_scalar('Dice/test', val_loss, global_step)

                    writer.add_images('images', image, global_step)
                    if model.n_classes == 1:
                        writer.add_images('masks/true', mask, global_step)
                        image = image.to(device=device, dtype=torch.float32)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(model(image)) > 0.5,
                                          global_step)

                    # BUG FIX: restore training mode — the original left the
                    # model in eval() (dropout/batch-norm frozen) for the
                    # rest of the epoch after each validation pass.
                    model.train()

    writer.close()
Exemplo n.º 25
0
def train(opt):
    """Run the multi-task EfficientDet training loop.

    Builds train/validation loaders from ``opt.train_path``, trains for
    ``opt.epochs`` epochs with a per-batch OneCycle LR schedule, logs
    per-head losses and sample predictions to TensorBoard under a
    date-stamped directory, and checkpoints the model: the best model by
    *mean* validation loss, a periodic snapshot every 100 epochs, and a
    final snapshot at the end.

    Args:
        opt: namespace providing at least ``train_path``, ``batch_size``,
            ``heads``, ``optim``, ``lr``, ``momentum``, ``epochs``,
            ``valid_step`` and ``save_name``.
    """
    date = datetime.date(datetime.now())
    logs = '../logs/'
    logdir = os.path.join(logs, str(date))
    if not os.path.exists(logdir):
        os.mkdir(logdir)
    else:
        # A run for today already exists: add a random suffix so its logs
        # are not clobbered (NOTE(review): can still collide, low odds).
        logdir = logdir + "_" + str(np.random.randint(0, 1000))
        os.mkdir(logdir)

    train_data = AllInOneData(opt.train_path, set='train',
                              transforms=transforms.Compose([Normalizer(), Resizer()]))
    train_generator = torch.utils.data.DataLoader(train_data, batch_size=opt.batch_size,
                                                  shuffle=True, num_workers=8,
                                                  collate_fn=collater, drop_last=True)

    valid_data = AllInOneData(opt.train_path, set='validation',
                              transforms=transforms.Compose([Normalizer(), Resizer()]))
    valid_generator = torch.utils.data.DataLoader(valid_data, batch_size=opt.batch_size,
                                                  shuffle=False, num_workers=8,
                                                  collate_fn=collater, drop_last=True)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = EfficientDetMultiBackbone(opt.train_path, compound_coef=0, heads=opt.heads)
    model.to(device)

    min_val_loss = 10e5

    if opt.optim == 'Adam':
        optimizer = torch.optim.AdamW(model.parameters(), lr=opt.lr)
    else:
        optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr,
                                    momentum=opt.momentum, nesterov=True)

    # OneCycleLR is stepped once per *batch*, hence steps_per_epoch below.
    scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, opt.lr, total_steps=None, epochs=opt.epochs,
                                                    steps_per_epoch=len(train_generator), pct_start=0.1, anneal_strategy='cos',
                                                    cycle_momentum=True, base_momentum=0.85, max_momentum=0.95,
                                                    div_factor=25.0, final_div_factor=1000.0, last_epoch=-1)

    criterion = MTLoss(heads=opt.heads, device=device)

    print('Model is successfully initiated')
    print(f'Targets are {opt.heads}.')
    verb_loss = 0
    writer = SummaryWriter(logdir=logdir, filename_suffix=f'Train_{"_".join(opt.heads)}', comment='try1')

    for epoch in range(opt.epochs):
        model.train()
        Losses = {k: [] for k in opt.heads}
        description = f'Epoch:{epoch}| Total Loss:{verb_loss}'
        progress_bar = tqdm(train_generator, desc=description)
        Total_loss = []
        for sample in progress_bar:
            imgs = sample['img'].to(device)
            gt_person_bbox = sample['person_bbox'].to(device)
            gt_face_bbox = sample['face_bbox'].to(device)
            gt_pose = sample['pose'].to(device)
            gt_face_landmarks = sample['face_landmarks'].to(device)
            gt_gender = sample['gender'].to(device)
            gt_emotions = sample['emotion'].to(device)
            # age / race / skin annotations are not consumed by the loss,
            # so they are intentionally not moved to the device.

            out = model(imgs)
            annot = {'person': gt_person_bbox, 'gender': gt_gender,
                     'face': gt_face_bbox, 'emotions': gt_emotions,
                     'face_landmarks': gt_face_landmarks,
                     'pose': gt_pose}

            losses, lm_mask = criterion(out, annot, out['anchors'])
            loss = torch.sum(torch.cat(list(losses.values())))

            # Zero *before* backward so stale gradients can never leak in.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            scheduler.step()

            verb_loss = loss.detach().cpu().numpy()
            Total_loss.append(verb_loss)
            description = f'Epoch:{epoch}| Total Loss:{verb_loss}|'
            for k, v in losses.items():
                Losses[k].append(v.detach().cpu().numpy())
                description += f'{k}:{round(np.mean(Losses[k]),1)}|'
            progress_bar.set_description(description)

        writer.add_scalar('Train/Total', round(np.mean(Total_loss), 2), epoch)
        for k in Losses.keys():
            writer.add_scalar(f"Train/{k}", round(np.mean(Losses[k]), 2), epoch)

        if epoch % opt.valid_step == 0:
            # Visualize the last training batch: undo [-1, 1] normalization.
            im = (imgs[0] + 1) / 2 * 255

            regressBoxes = BBoxTransform()
            clipBoxes = ClipBoxes()
            pp = postprocess(imgs,
                  out['anchors'], out['person'], out['gender'],
                  regressBoxes, clipBoxes,
                  0.4, 0.4)

            writer.add_image_with_boxes('Train/Box_prediction', im, pp[0]['rois'], epoch)
            img2 = out['face_landmarks']
            if img2.shape[1] > 3:
                # Collapse per-landmark channels into one heatmap image.
                img2 = img2.sum(axis=1).unsqueeze(1) * 255
                lm_mask = lm_mask.sum(axis=1).unsqueeze(1) * 255
            writer.add_images('Train/landmarks_prediction', img2, epoch)
            writer.add_images('Train/landmark target', lm_mask, epoch)

            # VALIDATION STEPS
            model.eval()
            with torch.no_grad():
                valid_Losses = {k: [] for k in opt.heads}

                val_description = f'Validation| Total Loss:{verb_loss}'
                progress_bar = tqdm(valid_generator, desc=val_description)
                Total_loss = []
                for sample in progress_bar:
                    imgs = sample['img'].to(device)
                    gt_person_bbox = sample['person_bbox'].to(device)
                    gt_face_bbox = sample['face_bbox'].to(device)
                    gt_pose = sample['pose'].to(device)
                    gt_face_landmarks = sample['face_landmarks'].to(device)
                    gt_gender = sample['gender'].to(device)
                    gt_emotions = sample['emotion'].to(device)
                    out = model(imgs)
                    annot = {'person': gt_person_bbox, 'gender': gt_gender,
                             'face': gt_face_bbox, 'emotions': gt_emotions,
                             'face_landmarks': gt_face_landmarks,
                             'pose': gt_pose}

                    losses, lm_mask = criterion(out, annot, out['anchors'])

                    loss = torch.sum(torch.cat(list(losses.values())))
                    verb_loss = loss.detach().cpu().numpy()
                    Total_loss.append(verb_loss)
                    val_description = f'Validation| Total Loss:{verb_loss}|'
                    for k, v in losses.items():
                        valid_Losses[k].append(v.detach().cpu().numpy())
                        val_description += f'{k}:{round(np.mean(valid_Losses[k]),1)}|'
                    progress_bar.set_description(val_description)

                mean_val_loss = np.mean(Total_loss)
                writer.add_scalar('Validation/Total', round(mean_val_loss, 2), epoch)
                for k in valid_Losses.keys():
                    writer.add_scalar(f"Validation/{k}", round(np.mean(valid_Losses[k]), 2), epoch)

                im = (imgs[0] + 1) / 2 * 255

                regressBoxes = BBoxTransform()
                clipBoxes = ClipBoxes()
                pp = postprocess(imgs,
                  out['anchors'], out['person'], out['gender'],
                  regressBoxes, clipBoxes,
                  0.4, 0.4)

                writer.add_image_with_boxes('Validation/Box_prediction', im, pp[0]['rois'], epoch)

                img2 = out['face_landmarks']
                if img2.shape[1] > 3:
                    img2 = img2.sum(axis=1).unsqueeze(1) * 255
                    lm_mask = lm_mask.sum(axis=1).unsqueeze(1) * 255
                writer.add_images('Validation/landmarks_prediction', img2, epoch)
                writer.add_images('Validation/landmark target', lm_mask, epoch)

                # Select "best" on the *mean* validation loss (the quantity
                # logged above), not on the last batch's loss, which is noisy.
                if mean_val_loss < min_val_loss:
                    print("The model improved and checkpoint is saved.")
                    torch.save(model.state_dict(), f'{logdir}/{opt.save_name.split(".pt")[0]}_best_epoch_{epoch}.pt')
                    min_val_loss = mean_val_loss

        if epoch % 100 == 0:
            torch.save(model.state_dict(), f'{logdir}/{opt.save_name.split(".pt")[0]}_epoch_{epoch}.pt')
    torch.save(model.state_dict(), f'{logdir}/{opt.save_name.split(".pt")[0]}_last.pt')
    writer.close()
Exemplo n.º 26
0
def train(args):
    """Train a surface-normal (RGB / depth / RGB-D) network with evaluation.

    Sets up train/eval loaders for ``args.dataset``, optionally loads a
    pretrained state (RGB) or a previous checkpoint (depth-only), trains
    with RMSprop + MultiStepLR, logs losses and image panels to
    TensorBoard, and saves the checkpoint with the best mean eval loss.

    Args:
        args: namespace providing at least ``writer``, ``dataset``,
            ``img_rows``, ``img_cols``, ``img_norm``, ``batch_size``,
            ``num_workers``, ``arch_RGB``, ``pretrain``, ``input``,
            ``state_name``, ``loss``, ``l_rate``, ``n_epoch``,
            ``l1regular``, ``gradloss``, ``model_savepath`` and
            ``model_num``.
    """
    writer = SummaryWriter(comment=args.writer)

    # data loader setting, train and evaluation
    data_loader = get_loader(args.dataset)
    data_path = get_data_path(args.dataset)
    t_loader = data_loader(data_path, split='train', img_size=(args.img_rows, args.img_cols), img_norm=args.img_norm)
    v_loader = data_loader(data_path, split='test', img_size=(args.img_rows, args.img_cols), img_norm=args.img_norm)

    trainloader = data.DataLoader(t_loader, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers)
    evalloader = data.DataLoader(v_loader, batch_size=args.batch_size, num_workers=args.num_workers)
    print("Finish Loader Setup")

    # Setup Model and load pretrained model
    model_name = args.arch_RGB
    model = get_model(model_name, True)  # vgg_16
    if args.pretrain:  # True by default
        if args.input == 'rgb':  # only for rgb we have pretrain option
            state = get_premodel(model, args.state_name)
            model.load_state_dict(state)
            model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
        elif args.input == 'd':  # for d, resume from a previous best checkpoint
            print("Load training model: {}_{}_{}_{}_best.pkl".format(args.arch_RGB, args.dataset, args.loss, 1))
            checkpoint = torch.load(pjoin(args.model_savepath_pretrain,
                                          "{}_{}_{}_{}_best.pkl".format(args.arch_RGB, args.dataset, args.loss, 1)))
            # wrap first so the checkpoint's "module."-prefixed keys match
            model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))
            model.load_state_dict(checkpoint['model_D_state'])
    else:
        model = torch.nn.DataParallel(model, device_ids=range(torch.cuda.device_count()))

    model.cuda()
    print("Finish model setup with model %s and state_dict %s" % (args.arch_RGB, args.state_name))

    # optimizers and lr-decay setting (pretrained nets start lower and decay earlier)
    if args.pretrain:  # True by default
        optimizer_RGB = torch.optim.RMSprop(model.parameters(), lr=0.25 * args.l_rate)
        scheduler_RGB = torch.optim.lr_scheduler.MultiStepLR(optimizer_RGB, milestones=[1, 2, 4, 8], gamma=0.5)
    else:
        optimizer_RGB = torch.optim.RMSprop(model.parameters(), lr=args.l_rate)
        scheduler_RGB = torch.optim.lr_scheduler.MultiStepLR(optimizer_RGB, milestones=[1, 3, 5, 8, 11, 15], gamma=0.5)

    # forward and backward
    best_loss = 3
    n_iter_t, n_iter_v = 0, 0
    if args.dataset == 'matterport':
        total_iter_t = 105432 / args.batch_size
    elif args.dataset == 'scannet':
        total_iter_t = 59743 / args.batch_size
    else:
        total_iter_t = 0

    if not os.path.exists(args.model_savepath):
        os.makedirs(args.model_savepath)

    for epoch in range(args.n_epoch):

        model.train()

        for i, (images, labels, masks, valids, depthes, meshdepthes) in enumerate(trainloader):
            n_iter_t += 1

            images = Variable(images.contiguous().cuda())
            labels = Variable(labels.contiguous().cuda())
            masks = Variable(masks.contiguous().cuda())

            optimizer_RGB.zero_grad()
            if args.input == 'rgb':
                outputs = model(images)
            else:
                depthes = Variable(depthes.contiguous().cuda())
                if args.input == 'rgbd':
                    rgbd_input = torch.cat((images, depthes), dim=1)
                    outputs = model(rgbd_input)
                elif args.input == 'd':
                    outputs = model(depthes)

            # get_lossfun returns both the scalar loss and the gradient (df)
            # to backpropagate directly through `outputs`.
            loss, df = get_lossfun(args.loss, outputs, labels, masks)
            if args.l1regular:
                loss_rgl, df_rgl = get_lossfun('l1gra', outputs, labels, masks)
            elif args.gradloss:
                loss_grad, df_grad = get_lossfun('gradmap', outputs, labels, masks)

            if args.l1regular:
                outputs.backward(gradient=df, retain_graph=True)
                outputs.backward(gradient=0.1 * df_rgl)
            elif args.gradloss:
                outputs.backward(gradient=df, retain_graph=True)
                outputs.backward(gradient=0.5 * df_grad)
            else:
                outputs.backward(gradient=df)

            optimizer_RGB.step()

            if (i + 1) % 100 == 0:
                if args.l1regular:
                    print("Epoch [%d/%d] Iter [%d/%d] Loss and RGL: %.4f, %.4f" % (
                        epoch + 1, args.n_epoch, i, total_iter_t, loss.data, loss_rgl.data))
                elif args.gradloss:
                    print("Epoch [%d/%d] Iter [%d/%d] Loss and GradLoss: %.4f, %.4f" % (
                        epoch + 1, args.n_epoch, i, total_iter_t, loss.data, loss_grad.data))
                else:
                    print("Epoch [%d/%d] Iter [%d/%d] Loss: %.4f" % (
                        epoch + 1, args.n_epoch, i, total_iter_t, loss.data))

            if (i + 1) % 250 == 0:
                writer.add_scalar('loss/trainloss', loss.data.item(), n_iter_t)
                if args.l1regular:
                    writer.add_scalar('loss/trainloss_rgl', loss_rgl.data.item(), n_iter_t)
                elif args.gradloss:
                    writer.add_scalar('loss/trainloss_grad', loss_grad.data.item(), n_iter_t)

                writer.add_images('Image', images + 0.5, n_iter_t)

                if args.input != 'rgb':
                    writer.add_images('Depth', np.repeat(
                        (depthes - torch.min(depthes)) / (torch.max(depthes) - torch.min(depthes)), 3, axis=1),
                                      n_iter_t)
                writer.add_images('Label', 0.5 * (labels.permute(0, 3, 1, 2) + 1), n_iter_t)
                outputs_n = norm_tf(outputs)
                writer.add_images('Output', outputs_n, n_iter_t)

        # Step the LR scheduler *after* the epoch's optimizer steps
        # (PyTorch >= 1.1 ordering); stepping before would apply milestone-1
        # decay already during the first epoch.
        scheduler_RGB.step()

        model.eval()
        mean_loss, sum_loss, sum_rgl, sum_grad = 0, 0, 0, 0
        evalcount = 0
        with torch.no_grad():
            for i_val, (images_val, labels_val, masks_val, valids_val, depthes_val, meshdepthes_val) in tqdm(
                    enumerate(evalloader)):
                n_iter_v += 1
                images_val = Variable(images_val.contiguous().cuda())
                labels_val = Variable(labels_val.contiguous().cuda())
                masks_val = Variable(masks_val.contiguous().cuda())

                if args.input == 'rgb':
                    outputs = model(images_val)
                else:
                    depthes_val = Variable(depthes_val.contiguous().cuda())
                    if args.input == 'rgbd':
                        rgbd_input = torch.cat((images_val, depthes_val), dim=1)
                        outputs = model(rgbd_input)
                    elif args.input == 'd':
                        outputs = model(depthes_val)

                loss, df = get_lossfun(args.loss, outputs, labels_val, masks_val, False)  # valid_val not used in fact
                if args.l1regular:
                    loss_rgl, df_rgl = get_lossfun('l1gra', outputs, labels_val, masks_val, False)
                elif args.gradloss:
                    loss_grad, df_grad = get_lossfun('gradmap', outputs, labels_val, masks_val, False)

                # Skip NaN/Inf batches so they do not poison the epoch mean.
                if ((np.isnan(loss)) | (np.isinf(loss))):
                    sum_loss += 0
                else:
                    sum_loss += loss
                    evalcount += 1
                    if args.l1regular:
                        sum_rgl += loss_rgl
                    elif args.gradloss:
                        sum_grad += loss_grad

                if (i_val + 1) % 250 == 0:
                    # log eval panels with the *eval* counter and eval tags so
                    # they do not overwrite the training images
                    writer.add_scalar('loss/evalloss', loss, n_iter_v)

                    writer.add_images('Eval Image', images_val + 0.5, n_iter_v)
                    if args.input != 'rgb':
                        writer.add_images('Eval Depth', np.repeat(
                            (depthes_val - torch.min(depthes_val)) / (torch.max(depthes_val) - torch.min(depthes_val)),
                            3, axis=1), n_iter_v)
                    writer.add_images('Eval Label', 0.5 * (labels_val.permute(0, 3, 1, 2) + 1), n_iter_v)
                    outputs_n = norm_tf(outputs)
                    writer.add_images('Eval Output', outputs_n, n_iter_v)

            if evalcount > 0:
                mean_loss = sum_loss / evalcount
            else:
                # every eval batch was NaN/Inf: avoid ZeroDivisionError and
                # make sure this epoch can never be saved as "best"
                mean_loss = float('inf')
            print("Epoch [%d/%d] Evaluation Mean Loss: %.4f" % (epoch + 1, args.n_epoch, mean_loss))
            writer.add_scalar('loss/evalloss_mean', mean_loss, epoch)
            writer.add_scalar('loss/evalloss_rgl_mean', sum_rgl / max(evalcount, 1), epoch)
            writer.add_scalar('loss/evalloss_grad_mean', sum_grad / max(evalcount, 1), epoch)

        if mean_loss < best_loss:
            best_loss = mean_loss
            state = {'epoch': epoch + 1,
                     'model_RGB_state': model.state_dict(),
                     'optimizer_RGB_state': optimizer_RGB.state_dict(), }
            if args.pretrain:
                if args.l1regular:
                    torch.save(state, pjoin(args.model_savepath,
                                            "{}_{}_{}_{}_rgls_best.pkl".format(args.arch_RGB, args.dataset, args.loss,
                                                                               args.model_num)))
                elif args.gradloss:
                    torch.save(state, pjoin(args.model_savepath,
                                            "{}_{}_{}_{}_grad_best.pkl".format(args.arch_RGB, args.dataset, args.loss,
                                                                               args.model_num)))
                else:
                    torch.save(state, pjoin(args.model_savepath,
                                            "{}_{}_{}_{}_resume_RGB_best.pkl".format(args.arch_RGB, args.dataset,
                                                                                     args.loss, args.model_num)))
            else:
                torch.save(state, pjoin(args.model_savepath,
                                        "{}_{}_{}_{}_resume_RGB_best.pkl".format(args.arch_RGB, args.dataset, args.loss,
                                                                                 args.model_num)))

    print('Finish training for dataset %s trial %s' % (args.dataset, args.model_num))

    writer.export_scalars_to_json("./{}_{}_{}_{}.json".format(args.arch_RGB, args.dataset, args.loss, args.model_num))
    writer.close()
Exemplo n.º 27
0
    fd_optimizer.step()

    pd_optimizer.zero_grad()
    pd_loss.backward()
    pd_optimizer.step()

    if (i + 1) % args.save_model_interval == 0 or (i + 1) == args.max_iter:
        #torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_{i + 1}.pth')
        #torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_{i + 1}.pth')
        #torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_{i + 1}.pth')
        torch.save(g_model.state_dict(), f'{args.save_dir}/ckpt/G_10000.pth')
        torch.save(fd_model.state_dict(), f'{args.save_dir}/ckpt/FD_10000.pth')
        torch.save(pd_model.state_dict(), f'{args.save_dir}/ckpt/PD_10000.pth')

    if (i + 1) % args.log_interval == 0:
        writer.add_scalar('g_loss/recon_loss', recon_loss.item(), i + 1)
        writer.add_scalar('g_loss/cons_loss', cons_loss.item(), i + 1)
        writer.add_scalar('g_loss/gan_loss', gan_loss.item(), i + 1)
        writer.add_scalar('g_loss/total_loss', total_loss.item(), i + 1)
        writer.add_scalar('d_loss/fd_loss', fd_loss.item(), i + 1)
        writer.add_scalar('d_loss/pd_loss', pd_loss.item(), i + 1)

    def denorm(x):
        out = (x + 1) / 2 # [-1,1] -> [0,1]
        return out.clamp_(0, 1)
    if (i + 1) % args.vis_interval == 0:
        ims = torch.cat([img, masked, refine_result], dim=3)
        writer.add_images('raw_masked_refine', denorm(ims), i + 1)

writer.close()
def main(config):
    matrix = torch.load("matrix_obj_vs_att.pt")
    cudnn.benchmark = True
    device = torch.device('cuda:1')

    log_save_dir, model_save_dir, sample_save_dir, result_save_dir = prepare_dir(
        config.exp_name)

    attribute_nums = 106

    data_loader, _ = get_dataloader_vg(batch_size=config.batch_size,
                                       attribute_embedding=attribute_nums,
                                       image_size=config.image_size)

    vocab_num = data_loader.dataset.num_objects

    if config.clstm_layers == 0:
        netG = Generator_nolstm(num_embeddings=vocab_num,
                                embedding_dim=config.embedding_dim,
                                z_dim=config.z_dim).to(device)
    else:
        netG = Generator(num_embeddings=vocab_num,
                         obj_att_dim=config.embedding_dim,
                         z_dim=config.z_dim,
                         clstm_layers=config.clstm_layers,
                         obj_size=config.object_size,
                         attribute_dim=attribute_nums).to(device)

    netD_image = ImageDiscriminator(conv_dim=config.embedding_dim).to(device)
    netD_object = ObjectDiscriminator(n_class=vocab_num).to(device)
    netD_att = AttributeDiscriminator(n_attribute=attribute_nums).to(device)

    netD_image = add_sn(netD_image)
    netD_object = add_sn(netD_object)
    netD_att = add_sn(netD_att)

    netG_optimizer = torch.optim.Adam(netG.parameters(), config.learning_rate,
                                      [0.5, 0.999])
    netD_image_optimizer = torch.optim.Adam(netD_image.parameters(),
                                            config.learning_rate, [0.5, 0.999])
    netD_object_optimizer = torch.optim.Adam(netD_object.parameters(),
                                             config.learning_rate,
                                             [0.5, 0.999])
    netD_att_optimizer = torch.optim.Adam(netD_att.parameters(),
                                          config.learning_rate, [0.5, 0.999])

    start_iter_ = load_model(netD_object,
                             model_dir=model_save_dir,
                             appendix='netD_object',
                             iter=config.resume_iter)

    start_iter_ = load_model(netD_att,
                             model_dir=model_save_dir,
                             appendix='netD_attribute',
                             iter=config.resume_iter)

    start_iter_ = load_model(netD_image,
                             model_dir=model_save_dir,
                             appendix='netD_image',
                             iter=config.resume_iter)

    start_iter = load_model(netG,
                            model_dir=model_save_dir,
                            appendix='netG',
                            iter=config.resume_iter)

    data_iter = iter(data_loader)

    if start_iter < config.niter:

        if config.use_tensorboard: writer = SummaryWriter(log_save_dir)

        for i in range(start_iter, config.niter):

            # =================================================================================== #
            #                             1. Preprocess input data                                #
            # =================================================================================== #
            try:
                batch = next(data_iter)
            except:
                data_iter = iter(data_loader)
                batch = next(data_iter)

            imgs, objs, boxes, masks, obj_to_img, attribute, masks_shift, boxes_shift = batch
            z = torch.randn(objs.size(0), config.z_dim)

            att_idx = attribute.sum(dim=1).nonzero().squeeze()
            # print("Train D")
            # =================================================================================== #
            #                             2. Train the discriminator                              #
            # =================================================================================== #
            imgs, objs, boxes, masks, obj_to_img, z, attribute, masks_shift, boxes_shift \
                = imgs.to(device), objs.to(device), boxes.to(device), masks.to(device), obj_to_img, z.to(
                device), attribute.to(device), masks_shift.to(device), boxes_shift.to(device)

            attribute_GT = attribute.clone()

            # estimate attributes
            attribute_est = attribute.clone()
            att_mask = torch.zeros(attribute.shape[0])
            att_mask = att_mask.scatter(0, att_idx, 1).to(device)

            crops_input = crop_bbox_batch(imgs, boxes, obj_to_img,
                                          config.object_size)
            estimated_att = netD_att(crops_input)
            max_idx = estimated_att.argmax(1)
            max_idx = max_idx.float() * (~att_mask.byte()).float().to(device)
            for row in range(attribute.shape[0]):
                if row not in att_idx:
                    attribute_est[row, int(max_idx[row])] = 1

            # change GT attribute:
            num_img_to_change = math.floor(imgs.shape[0] / 3)
            for img_idx in range(num_img_to_change):
                obj_indices = torch.nonzero(obj_to_img == img_idx).view(-1)

                num_objs_to_change = math.floor(len(obj_indices) / 2)
                for changed, obj_idx in enumerate(obj_indices):
                    if changed >= num_objs_to_change:
                        break
                    obj = objs[obj_idx]
                    # change GT attribute
                    old_attributes = torch.nonzero(
                        attribute_GT[obj_idx]).view(-1)
                    new_attribute = random.choices(range(106),
                                                   matrix[obj].scatter(
                                                       0, old_attributes.cpu(),
                                                       0),
                                                   k=random.randrange(1, 3))
                    attribute[obj_idx] = 0  # remove all attributes for obj
                    attribute[obj_idx] = attribute[obj_idx].scatter(
                        0,
                        torch.LongTensor(new_attribute).to(device),
                        1)  # assign new attribute

                    # change estimated attributes
                    attribute_est[obj_idx] = 0  # remove all attributes for obj
                    attribute_est[obj_idx] = attribute[obj_idx].scatter(
                        0,
                        torch.LongTensor(new_attribute).to(device), 1)

            # Generate fake image
            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # Compute image adv loss with fake images.
            out_logits = netD_image(img_rec.detach())
            d_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            out_logits = netD_image(img_rand.detach())
            d_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # shift image adv loss
            out_logits = netD_image(img_shift.detach())
            d_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            d_image_adv_loss_fake = 0.4 * d_image_adv_loss_fake_rec + 0.4 * d_image_adv_loss_fake_rand + 0.2 * d_image_adv_loss_fake_shift

            # Compute image src loss with real images rec.
            out_logits = netD_image(imgs)
            d_image_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # Compute object sn adv loss with fake rec crops
            out_logits, _ = netD_object(crops_input_rec.detach(), objs)
            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # Compute object sn adv loss with fake rand crops
            out_logits, _ = netD_object(crops_rand.detach(), objs)

            d_object_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            # shift obj adv loss
            out_logits, _ = netD_object(crops_shift.detach(), objs)
            d_object_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 0))

            d_object_adv_loss_fake = 0.4 * g_object_adv_loss_rec + 0.4 * d_object_adv_loss_fake_rand + 0.2 * d_object_adv_loss_fake_shift

            # Compute object sn adv loss with real crops.
            out_logits_src, out_logits_cls = netD_object(
                crops_input.detach(), objs)

            d_object_adv_loss_real = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))

            # cls
            d_object_cls_loss_real = F.cross_entropy(out_logits_cls, objs)
            # attribute
            att_cls = netD_att(crops_input.detach())
            att_idx = attribute_GT.sum(dim=1).nonzero().squeeze()
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            attribute_annotated = torch.index_select(attribute_GT, 0, att_idx)
            d_object_att_cls_loss_real = F.binary_cross_entropy_with_logits(
                att_cls_annotated,
                attribute_annotated,
                pos_weight=pos_weight.to(device))

            # Backward and optimize.
            d_loss = 0
            d_loss += config.lambda_img_adv * (d_image_adv_loss_fake +
                                               d_image_adv_loss_real)
            d_loss += config.lambda_obj_adv * (d_object_adv_loss_fake +
                                               d_object_adv_loss_real)
            d_loss += config.lambda_obj_cls * d_object_cls_loss_real
            d_loss += config.lambda_att_cls * d_object_att_cls_loss_real

            netD_image.zero_grad()
            netD_object.zero_grad()
            netD_att.zero_grad()

            d_loss.backward()

            netD_image_optimizer.step()
            netD_object_optimizer.step()
            netD_att_optimizer.step()

            # Logging.
            loss = {}
            loss['D/loss'] = d_loss.item()
            loss['D/image_adv_loss_real'] = d_image_adv_loss_real.item()
            loss['D/image_adv_loss_fake'] = d_image_adv_loss_fake.item()
            loss['D/object_adv_loss_real'] = d_object_adv_loss_real.item()
            loss['D/object_adv_loss_fake'] = d_object_adv_loss_fake.item()
            loss['D/object_cls_loss_real'] = d_object_cls_loss_real.item()
            loss['D/object_att_cls_loss'] = d_object_att_cls_loss_real.item()

            # print("train G")
            # =================================================================================== #
            #                               3. Train the generator                                #
            # =================================================================================== #
            # Generate fake image

            output = netG(imgs, objs, boxes, masks, obj_to_img, z, attribute,
                          masks_shift, boxes_shift, attribute_est)
            crops_input, crops_input_rec, crops_rand, crops_shift, img_rec, img_rand, img_shift, mu, logvar, z_rand_rec, z_rand_shift = output

            # reconstruction loss of ae and img
            rec_img_mask = torch.ones(imgs.shape[0]).scatter(
                0, torch.LongTensor(range(num_img_to_change)), 0).to(device)
            g_img_rec_loss = rec_img_mask * torch.abs(img_rec - imgs).view(
                imgs.shape[0], -1).mean(1)
            g_img_rec_loss = g_img_rec_loss.sum() / (imgs.shape[0] -
                                                     num_img_to_change)

            g_z_rec_loss_rand = torch.abs(z_rand_rec - z).mean()
            g_z_rec_loss_shift = torch.abs(z_rand_shift - z).mean()
            g_z_rec_loss = 0.5 * g_z_rec_loss_rand + 0.5 * g_z_rec_loss_shift

            # kl loss
            kl_element = mu.pow(2).add_(
                logvar.exp()).mul_(-1).add_(1).add_(logvar)
            g_kl_loss = torch.sum(kl_element).mul_(-0.5)

            # Compute image adv loss with fake images.
            out_logits = netD_image(img_rec)

            g_image_adv_loss_fake_rec = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            out_logits = netD_image(img_rand)
            g_image_adv_loss_fake_rand = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            # shift image adv loss
            out_logits = netD_image(img_shift)
            g_image_adv_loss_fake_shift = F.binary_cross_entropy_with_logits(
                out_logits, torch.full_like(out_logits, 1))

            g_image_adv_loss_fake = 0.4 * g_image_adv_loss_fake_rec + 0.4 * g_image_adv_loss_fake_rand + 0.2 * g_image_adv_loss_fake_shift

            # Compute object adv loss with fake images.
            out_logits_src, out_logits_cls = netD_object(crops_input_rec, objs)

            g_object_adv_loss_rec = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rec = F.cross_entropy(out_logits_cls, objs)
            # attribute
            att_cls = netD_att(crops_input_rec)
            att_idx = attribute.sum(dim=1).nonzero().squeeze()
            attribute_annotated = torch.index_select(attribute, 0, att_idx)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rec = F.binary_cross_entropy_with_logits(
                att_cls_annotated,
                attribute_annotated,
                pos_weight=pos_weight.to(device))

            out_logits_src, out_logits_cls = netD_object(crops_rand, objs)
            g_object_adv_loss_rand = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))
            g_object_cls_loss_rand = F.cross_entropy(out_logits_cls, objs)
            # attribute
            att_cls = netD_att(crops_rand)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_rand = F.binary_cross_entropy_with_logits(
                att_cls_annotated,
                attribute_annotated,
                pos_weight=pos_weight.to(device))

            # shift adv obj loss
            out_logits_src, out_logits_cls = netD_object(crops_shift, objs)
            g_object_adv_loss_shift = F.binary_cross_entropy_with_logits(
                out_logits_src, torch.full_like(out_logits_src, 1))

            g_object_cls_loss_shift = F.cross_entropy(out_logits_cls, objs)
            # attribute
            att_cls = netD_att(crops_shift)
            att_cls_annotated = torch.index_select(att_cls, 0, att_idx)
            g_object_att_cls_loss_shift = F.binary_cross_entropy_with_logits(
                att_cls_annotated,
                attribute_annotated,
                pos_weight=pos_weight.to(device))

            g_object_att_cls_loss = 0.4 * g_object_att_cls_loss_rec + 0.4 * g_object_att_cls_loss_rand + 0.2 * g_object_att_cls_loss_shift

            g_object_adv_loss = 0.4 * g_object_adv_loss_rec + 0.4 * g_object_adv_loss_rand + 0.2 * g_object_adv_loss_shift
            g_object_cls_loss = 0.4 * g_object_cls_loss_rec + 0.4 * g_object_cls_loss_rand + 0.2 * g_object_cls_loss_shift

            # Backward and optimize.
            g_loss = 0
            g_loss += config.lambda_img_rec * g_img_rec_loss
            g_loss += config.lambda_z_rec * g_z_rec_loss
            g_loss += config.lambda_img_adv * g_image_adv_loss_fake
            g_loss += config.lambda_obj_adv * g_object_adv_loss
            g_loss += config.lambda_obj_cls * g_object_cls_loss
            g_loss += config.lambda_att_cls * g_object_att_cls_loss
            g_loss += config.lambda_kl * g_kl_loss

            netG.zero_grad()

            g_loss.backward()

            netG_optimizer.step()

            loss['G/loss'] = g_loss.item()
            loss['G/image_adv_loss'] = g_image_adv_loss_fake.item()
            loss['G/object_adv_loss'] = g_object_adv_loss.item()
            loss['G/object_cls_loss'] = g_object_cls_loss.item()
            loss['G/rec_img'] = g_img_rec_loss.item()
            loss['G/rec_z'] = g_z_rec_loss.item()
            loss['G/kl'] = g_kl_loss.item()
            loss['G/object_att_cls_loss'] = g_object_att_cls_loss.item()

            # =================================================================================== #
            #                               4. Log                                                #
            # =================================================================================== #
            if (i + 1) % config.log_step == 0:
                log = 'iter [{:06d}/{:06d}]'.format(i + 1, config.niter)
                for tag, roi_value in loss.items():
                    log += ", {}: {:.4f}".format(tag, roi_value)
                print(log)

            if (i + 1
                ) % config.tensorboard_step == 0 and config.use_tensorboard:
                for tag, roi_value in loss.items():

                    writer.add_scalar(tag, roi_value, i + 1)
                writer.add_images(
                    'Result/crop_real',
                    imagenet_deprocess_batch(crops_input).float() / 255, i + 1)
                writer.add_images(
                    'Result/crop_real_rec',
                    imagenet_deprocess_batch(crops_input_rec).float() / 255,
                    i + 1)
                writer.add_images(
                    'Result/crop_rand',
                    imagenet_deprocess_batch(crops_rand).float() / 255, i + 1)
                writer.add_images('Result/img_real',
                                  imagenet_deprocess_batch(imgs).float() / 255,
                                  i + 1)
                writer.add_images(
                    'Result/img_real_rec',
                    imagenet_deprocess_batch(img_rec).float() / 255, i + 1)
                writer.add_images(
                    'Result/img_fake_rand',
                    imagenet_deprocess_batch(img_rand).float() / 255, i + 1)

            if (i + 1) % config.save_step == 0:

                # netG_noDP.load_state_dict(new_state_dict)
                save_model(netG,
                           model_dir=model_save_dir,
                           appendix='netG',
                           iter=i + 1,
                           save_num=2,
                           save_step=config.save_step)
                save_model(netD_image,
                           model_dir=model_save_dir,
                           appendix='netD_image',
                           iter=i + 1,
                           save_num=2,
                           save_step=config.save_step)
                save_model(netD_object,
                           model_dir=model_save_dir,
                           appendix='netD_object',
                           iter=i + 1,
                           save_num=2,
                           save_step=config.save_step)
                save_model(netD_att,
                           model_dir=model_save_dir,
                           appendix='netD_attribute',
                           iter=i + 1,
                           save_num=2,
                           save_step=config.save_step)

        if config.use_tensorboard: writer.close()
                    image,
                    func_label2color=visualization.VOClabel2colormap,
                    threshold=None,
                    norm=False)
                # MASK = eq_mask[0].detach().cpu().numpy().astype(np.uint8)*255
                loss_dict = {
                    'loss': loss.item(),
                    'loss_cls': loss_cls.item(),
                    'loss_er': loss_er.item(),
                    'loss_ecr': loss_ecr.item()
                }
                itr = optimizer.global_step - 1
                tblogger.add_scalars('loss', loss_dict, itr)
                tblogger.add_scalar('lr', optimizer.param_groups[0]['lr'], itr)
                tblogger.add_image('Image', input_img, itr)
                # tblogger.add_image('Mask', MASK, itr)
                tblogger.add_image('CLS1', CLS1, itr)
                tblogger.add_image('CLS2', CLS2, itr)
                tblogger.add_image('CLS_RV1', CLS_RV1, itr)
                tblogger.add_image('CLS_RV2', CLS_RV2, itr)
                tblogger.add_images('CAM1', CAM1, itr)
                tblogger.add_images('CAM2', CAM2, itr)
                tblogger.add_images('CAM_RV1', CAM_RV1, itr)
                tblogger.add_images('CAM_RV2', CAM_RV2, itr)

        else:
            print('')
            timer.reset_stage()

    torch.save(model.module.state_dict(), args.session_name + '.pth')
def train_gan(dataloader, model_folder, netG, netD, netS, netEs, netEb, args):
    """Run the text/segmentation/background-conditioned GAN training loop.

    Alternates discriminator and generator updates per batch, logs images and
    losses to tensorboard, decays both learning rates every
    ``args.epoch_decay`` epochs, and checkpoints every ``args.save_freq``
    epochs.

    Parameters:
    ----------
    dataloader:
        data loader. refers to fuel.dataset
    model_folder:
        the folder to save the model weights; also used as the
        tensorboard log directory
    netG:
        Generator
    netD:
        Discriminator
    netS:
        Segmentation Network (kept frozen in eval mode; weights loaded
        from ``args.unet_checkpoint``)
    netEs:
        Segmentation Encoder
    netEb:
        Background Encoder
    args:
        hyper-parameter namespace (learning rates, ``maxepoch``,
        ``epoch_decay``, ``KL_COE``, ``manipulate``, logging/saving
        frequencies, checkpoint options, ...)
    """

    d_lr = args.d_lr
    g_lr = args.g_lr
    tot_epoch = args.maxepoch
    ''' configure optimizers '''
    optimizerD = optim.Adam(netD.parameters(), lr=d_lr, betas=(0.5, 0.999))
    # G, Es and Eb are trained jointly under a single optimizer.
    paramsG = list(netG.parameters()) + list(netEs.parameters()) + list(
        netEb.parameters())

    optimizerG = optim.Adam(paramsG, lr=g_lr, betas=(0.5, 0.999))
    ''' create tensorboard writer '''
    writer = SummaryWriter(model_folder)

    # --- load model from  checkpoint ---
    # Segmentation net is always restored; the GAN nets only when resuming.
    netS.load_state_dict(torch.load(args.unet_checkpoint))
    if args.reuse_weights:
        G_weightspath = os.path.join(
            model_folder, 'G_epoch{}.pth'.format(args.load_from_epoch))
        D_weightspath = os.path.join(
            model_folder, 'D_epoch{}.pth'.format(args.load_from_epoch))
        Es_weightspath = os.path.join(
            model_folder, 'Es_epoch{}.pth'.format(args.load_from_epoch))
        Eb_weightspath = os.path.join(
            model_folder, 'Eb_epoch{}.pth'.format(args.load_from_epoch))

        netG.load_state_dict(torch.load(G_weightspath))
        netD.load_state_dict(torch.load(D_weightspath))
        netEs.load_state_dict(torch.load(Es_weightspath))
        netEb.load_state_dict(torch.load(Eb_weightspath))

        start_epoch = args.load_from_epoch + 1
        # Re-apply the decay schedule the resumed run would have seen.
        d_lr /= 2**(start_epoch // args.epoch_decay)
        g_lr /= 2**(start_epoch // args.epoch_decay)

    else:
        start_epoch = 1

    # --- Start training ---
    for epoch in range(start_epoch, tot_epoch + 1):
        start_timer = time.time()
        '''decay learning rate every epoch_decay epoches'''
        if epoch % args.epoch_decay == 0:
            d_lr = d_lr / 2
            g_lr = g_lr / 2

            set_lr(optimizerD, d_lr)
            set_lr(optimizerG, g_lr)

        netG.train()
        netD.train()
        netEs.train()
        netEb.train()
        netS.eval()  # segmentation net is a fixed auxiliary, never updated

        for i, data in enumerate(dataloader):
            images, w_images, segs, txt_data, txt_len, _ = data

            # create labels (1 = real, 0 = fake) for the adversarial loss
            r_labels = torch.FloatTensor(images.size(0)).fill_(1).cuda()
            f_labels = torch.FloatTensor(images.size(0)).fill_(0).cuda()

            # global iteration counter used for tensorboard x-axis
            it = epoch * len(dataloader) + i

            # to cuda
            images = images.cuda()
            w_images = w_images.cuda()
            segs = segs.cuda()
            txt_data = txt_data.cuda()
            ''' UPDATE D '''
            for p in netD.parameters():
                p.requires_grad = True
            optimizerD.zero_grad()

            if args.manipulate:
                # manipulation mode: keep image/seg pairs aligned
                bimages = images  # for text and seg mismatched backgrounds
                bsegs = segs  # background segmentations
            else:
                # roll the batch to create mismatched background/seg pairs
                bimages = roll(
                    images, 2,
                    dim=0)  # for text and seg mismatched backgrounds
                bsegs = roll(segs, 2, dim=0)  # background segmentations
                segs = roll(segs, 1,
                            dim=0)  # for text mismatched segmentations

            segs_code = netEs(segs)  # segmentation encoding
            bkgs_code = netEb(bimages)  # background image encoding

            mean_var, smean_var, bmean_var, f_images, z_list = netG(
                txt_data, txt_len, segs_code, bkgs_code)

            # Detach so the D update does not backprop into G.
            # (was `.data.cuda()`: `.data` is deprecated/unsafe and the
            # tensor is already on the GPU)
            f_images_cp = f_images.detach()

            r_logit, r_logit_c = netD(images, txt_data, txt_len)
            _, w_logit_c = netD(w_images, txt_data, txt_len)
            f_logit, _ = netD(f_images_cp, txt_data, txt_len)

            d_adv_loss = compute_d_loss(r_logit, r_logit_c, w_logit_c, f_logit,
                                        r_labels, f_labels)

            d_loss = d_adv_loss
            d_loss.backward()
            optimizerD.step()
            optimizerD.zero_grad()
            ''' UPDATE G '''
            for p in netD.parameters():
                p.requires_grad = False  # to avoid computation
            optimizerG.zero_grad()

            f_logit, f_logit_c = netD(f_images, txt_data, txt_len)

            g_adv_loss = compute_g_loss(f_logit, f_logit_c, r_labels)
            f_segs = netS(f_images)  # segmentation from Unet
            # generated images must keep the conditioning segmentation ...
            seg_consist_loss = shape_consistency_loss(f_segs, segs)
            # ... and the conditioning background outside the segmented region
            bkg_consist_loss = background_consistency_loss(
                f_images, bimages, f_segs, bsegs)

            kl_loss = get_kl_loss(mean_var[0], mean_var[1])  # text
            skl_loss = get_kl_loss(smean_var[0], smean_var[1])  # segmentation
            bkl_loss = get_kl_loss(bmean_var[0], bmean_var[1])  # background

            if args.manipulate:
                idt_consist_loss = idt_consistency_loss(f_images, images)
            else:
                idt_consist_loss = 0.

            g_loss = g_adv_loss \
                    + args.KL_COE * kl_loss \
                    + args.KL_COE * skl_loss \
                    + args.KL_COE * bkl_loss \
                    + 10 * seg_consist_loss \
                    + 10 * bkg_consist_loss \
                    + 10 * idt_consist_loss

            g_loss.backward()
            optimizerG.step()
            optimizerG.zero_grad()

            # --- visualize train samples----
            if it % args.verbose_per_iter == 0:
                # images are in [-1, 1]; map to [0, 1] for display
                writer.add_images('txt', (images[:args.n_plots] + 1) / 2, it)
                writer.add_images('background',
                                  (bimages[:args.n_plots] + 1) / 2, it)
                writer.add_images('segmentation',
                                  segs[:args.n_plots].repeat(1, 3, 1, 1), it)
                writer.add_images('generated',
                                  (f_images[:args.n_plots] + 1) / 2, it)
                writer.add_scalar('g_lr', g_lr, it)
                # BUGFIX: was logging g_lr under the 'd_lr' tag, so the
                # discriminator learning rate was never recorded.
                writer.add_scalar('d_lr', d_lr, it)
                writer.add_scalar('g_loss', to_numpy(g_loss).mean(), it)
                writer.add_scalar('d_loss', to_numpy(d_loss).mean(), it)
                writer.add_scalar('imkl_loss', to_numpy(kl_loss).mean(), it)
                writer.add_scalar('segkl_loss', to_numpy(skl_loss).mean(), it)
                writer.add_scalar('bkgkl_loss', to_numpy(bkl_loss).mean(), it)
                writer.add_scalar('seg_consist_loss',
                                  to_numpy(seg_consist_loss).mean(), it)
                writer.add_scalar('bkg_consist_loss',
                                  to_numpy(bkg_consist_loss).mean(), it)
                if args.manipulate:
                    writer.add_scalar('idt_consist_loss',
                                      to_numpy(idt_consist_loss).mean(), it)

        # --- save weights ---
        if epoch % args.save_freq == 0:

            # Move to CPU before serializing so checkpoints load on
            # CPU-only machines, then move back for the next epoch.
            netG = netG.cpu()
            netD = netD.cpu()
            netEs = netEs.cpu()
            netEb = netEb.cpu()

            torch.save(
                netD.state_dict(),
                os.path.join(model_folder, 'D_epoch{}.pth'.format(epoch)))
            torch.save(
                netG.state_dict(),
                os.path.join(model_folder, 'G_epoch{}.pth'.format(epoch)))
            torch.save(
                netEs.state_dict(),
                os.path.join(model_folder, 'Es_epoch{}.pth'.format(epoch)))
            torch.save(
                netEb.state_dict(),
                os.path.join(model_folder, 'Eb_epoch{}.pth'.format(epoch)))

            print('save weights at {}'.format(model_folder))
            netD = netD.cuda()
            netG = netG.cuda()
            netEs = netEs.cuda()
            netEb = netEb.cuda()

        end_timer = time.time() - start_timer
        print('epoch {}/{} finished [time = {}s] ...'.format(
            epoch, tot_epoch, end_timer))

    writer.close()