Beispiel #1
0
def load_model():
    """Build the RGB "unfiller" CompletionNet and return it in eval mode.

    The network (BatchNorm2d, nf=64) is wrapped in ``nn.DataParallel`` and
    moved to the GPU. Weights are then loaded from ``unfiller_rgb.pth`` in
    ``assets_file_dir``. The unwrapped inner module is returned, so callers
    get a plain ``CompletionNet`` rather than the DataParallel wrapper.
    """
    wrapped = nn.DataParallel(CompletionNet(norm=nn.BatchNorm2d, nf=64)).cuda()
    weights_path = os.path.join(assets_file_dir, "unfiller_rgb.pth")
    wrapped.load_state_dict(torch.load(weights_path))

    unwrapped = wrapped.module
    unwrapped.eval()
    return unwrapped
    def __init__(self,
                 port,
                 imgs,
                 depths,
                 target,
                 target_poses,
                 scale_up,
                 semantics=None,
                 gui=True,
                 use_filler=True,
                 gpu_idx=0,
                 window_width=256,
                 window_height=None,
                 env=None):
        """Set up viewer state, ZMQ request sockets and the filler network.

        Args:
            port: base port. The mist, normal and semantics channels connect
                to ``port - 1``, ``port - 2`` and ``port - 3`` on localhost.
                The normal and semantics channels connect only when
                ``'normal'`` / ``'semantics'`` appear in
                ``env.config["output"]``.
            imgs, depths, target, semantics: stored unchanged on ``self``.
            target_poses: sequence of square pose matrices (presumably 4x4
                homogeneous transforms — TODO confirm). ``tp[:3, -1]`` of
                each becomes ``self.pose_locations``, and
                ``inv(tp) @ target_poses[0]`` becomes
                ``self.relative_poses``.
            scale_up, gpu_idx: accepted for interface compatibility. They
                are not used by this constructor.
            gui: if true and ``env.config["display_ui"]`` is false,
                ``renderToScreenSetup()`` is called at the end.
            use_filler: stored as ``self.use_filler``. Whether a filler
                model is actually loaded is decided by
                ``env.config["use_filler"]``.
            window_width, window_height: output size. The height defaults to
                the width.
            env: object whose ``config`` mapping supplies ``output``,
                ``resolution``, ``use_filler``, ``display_ui`` and optionally
                ``fast_lq_render``. Required despite the ``None`` default.
        """
        self.env = env

        # Camera pose and mouse-interaction state.
        self.roll, self.pitch, self.yaw = 0, 0, 0
        self.quat = [1, 0, 0, 0]
        self.x, self.y, self.z = 0, 0, 0
        self.fps = 0
        self.mousex, self.mousey = 0.5, 0.5
        self.org_pitch, self.org_yaw, self.org_roll = 0, 0, 0
        self.org_x, self.org_y, self.org_z = 0, 0, 0
        self.clickstart = (0, 0)
        self.mousedown = False
        self.overlay = False
        self.show_depth = False

        # One ZMQ context per channel. The phys and dept contexts are
        # created but no socket is opened on them here.
        self.port = port
        self._context_phys = zmq.Context()
        self._context_mist = zmq.Context()
        self._context_dept = zmq.Context()
        self._context_norm = zmq.Context()
        self._context_semt = zmq.Context()

        # Optional outputs requested by the environment configuration.
        self._require_semantics = 'semantics' in self.env.config["output"]
        self._require_normal = 'normal' in self.env.config["output"]

        self.socket_mist = self._context_mist.socket(zmq.REQ)
        self.socket_mist.connect("tcp://localhost:{}".format(self.port - 1))
        if self._require_normal:
            self.socket_norm = self._context_norm.socket(zmq.REQ)
            self.socket_norm.connect("tcp://localhost:{}".format(self.port -
                                                                 2))
        if self._require_semantics:
            self.socket_semt = self._context_semt.socket(zmq.REQ)
            self.socket_semt.connect("tcp://localhost:{}".format(self.port -
                                                                 3))

        self.target_poses = target_poses
        # Translation column of every pose.
        self.pose_locations = np.array(
            [tp[:3, -1] for tp in self.target_poses])
        # inv(pose_i) @ pose_0 for every pose i.
        self.relative_poses = [
            np.dot(np.linalg.inv(tg), self.target_poses[0])
            for tg in target_poses
        ]

        self.imgs = imgs
        self.depths = depths
        self.target = target
        self.semantics = semantics
        self.k = 5
        self.use_filler = use_filler

        self.window_width = window_width
        if window_height is None:
            self.window_height = window_width
        else:
            self.window_height = window_height
        self.capture_count = 0

        # NOTE(review): these buffers are allocated as (width, height, 3),
        # while the torch buffers below are (1, C, height, width). The two
        # layouts agree only for square windows. Confirm which layout the
        # code that fills them expects before using width != height.
        buffer_shape = (self.window_width, self.window_height, 3)
        self.show = np.zeros(buffer_shape, dtype='uint8')
        self.show_rgb = np.zeros(buffer_shape, dtype='uint8')
        self.show_semantics = np.zeros(buffer_shape, dtype='uint8')
        self.show_prefilled = np.zeros(buffer_shape, dtype='uint8')
        self.surface_normal = np.zeros(buffer_shape, dtype='uint8')

        self.semtimg_count = 0

        # Decide once whether the small/low-quality filler is requested, so
        # that the network architecture and the checkpoint file always
        # agree. Previously the architecture test used `== True` and the
        # checkpoint test used truthiness, so a truthy non-True value built
        # the nf=64 net but loaded the small checkpoint.
        fast_lq = ("fast_lq_render" in self.env.config
                   and bool(self.env.config["fast_lq_render"]))

        # The filler is only loaded when it will be kept. Previously it was
        # built, moved to the GPU and loaded from disk, then discarded when
        # use_filler was false.
        self.model = None
        if self.env.config["use_filler"]:
            # Round the configured resolution up to one of the checkpoint
            # sizes 64/128/256; anything larger uses 512.
            resolution = self.env.config["resolution"]
            res = next((size for size in (64, 128, 256) if resolution <= size),
                       512)

            if fast_lq:
                comp = CompletionNet(norm=nn.BatchNorm2d,
                                     nf=24,
                                     skip_first_bn=True)
                checkpoint = "model_small_{}.pth".format(res)
            else:
                comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)
                checkpoint = "model_{}.pth".format(res)

            # Weights are loaded through a DataParallel wrapper (presumably
            # because they were saved from one), then the inner module is
            # kept.
            comp = torch.nn.DataParallel(comp).cuda()
            comp.load_state_dict(
                torch.load(os.path.join(assets_file_dir, checkpoint)))
            self.model = comp.module
            self.model.eval()

        self.imgs_topk = None
        self.depths_topk = None
        self.relative_poses_topk = None
        self.old_topk = None

        self.imgv = Variable(torch.zeros(1, 3, self.window_height,
                                         self.window_width),
                             volatile=True).cuda()
        self.maskv = Variable(torch.zeros(1, 2, self.window_height,
                                          self.window_width),
                              volatile=True).cuda()
        # Per-channel RGB mean, tiled to (3, height, width).
        self.mean = torch.from_numpy(
            np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))
        self.mean = self.mean.view(3, 1, 1).repeat(1, self.window_height,
                                                   self.window_width)

        # NOTE(review): as written, on-screen setup happens only when
        # display_ui is *false*. Confirm that this is the intended polarity.
        if gui and not self.env.config["display_ui"]:
            self.renderToScreenSetup()
    # Gather per-view data from dataset `d` for each (k, v) pair in `uuids`.
    # Only `v` is used (as the dataset index). `d`, `uuids`, `targets`,
    # `poses`, `sources` and `source_depths` are defined earlier in the
    # enclosing function, which is not visible here.
    for k, v in uuids:
        # print(k,v)
        data = d[v]
        source = data[0][0]
        target = data[1]
        target_depth = data[3]
        source_depth = data[2][0]
        pose = data[-1][0].numpy()
        targets.append(target)
        poses.append(pose)
        # NOTE(review): `sources` / `source_depths` receive the *target*
        # image and depth, not the `source` / `source_depth` extracted above
        # (which are otherwise unused). Confirm that this is intentional.
        sources.append(target)
        source_depths.append(target_depth)

    # Optionally load a CompletionNet checkpoint from --model. The resulting
    # model is only printed; it is not passed to the renderer below.
    model = None
    if opt.model != '':
        comp = CompletionNet()
        comp = torch.nn.DataParallel(comp).cuda()
        comp.load_state_dict(torch.load(opt.model))
        model = comp.module
        model.eval()
    print(model)
    print('target', poses, poses[0])
    # print('no.1 pose', poses, poses[1])
    # print(source_depth)
    print(sources[0].shape, source_depths[0].shape)

    # `target` here is the one from the last loop iteration.
    show_target(target)

    # NOTE(review): `rts` is defined outside this visible fragment. The call
    # passes five positional arguments (port, imgs, depths, target, poses);
    # confirm that PCRenderer's signature does not also require `scale_up`.
    renderer = PCRenderer(5556, sources, source_depths, target, rts)
    # renderer.renderToScreen(sources, source_depths, poses, models, target, target_depth, rts)
    renderer.renderOffScreenSetup()
Beispiel #4
0
    def __init__(self, port, imgs, depths, target, target_poses, scale_up,
                 semantics=None, gui=True, use_filler=True, gpu_count=0,
                 windowsz=256, env=None):
        """Initialise renderer state, ZMQ request sockets and the filler net.

        Args:
            port, scale_up: accepted but not used by this constructor.
                NOTE(review): socket ports are hard-coded around 5555 rather
                than derived from ``port`` — confirm this is intended.
            imgs, depths, target, target_poses, semantics: stored unchanged.
            gui: if true and ``env.config["display_ui"]`` is false,
                ``renderToScreenSetup()`` is called at the end.
            use_filler: stored as ``self.use_filler``. Whether the loaded
                model is kept is decided by ``env.config["use_filler"]``.
            gpu_count: offset added to 5555 for the mist channel's port.
            windowsz: side length of the square output buffers.
            env: object whose ``config`` mapping supplies ``output``,
                ``resolution``, ``use_filler`` and ``display_ui``.
        """
        self.env = env

        # Camera orientation/position and mouse-drag bookkeeping.
        self.roll = self.pitch = self.yaw = 0
        self.quat = [1, 0, 0, 0]
        self.x = self.y = self.z = 0
        self.fps = 0
        self.mousex = self.mousey = 0.5
        self.org_pitch = self.org_yaw = self.org_roll = 0
        self.org_x = self.org_y = self.org_z = 0
        self.clickstart = (0, 0)
        self.mousedown = self.overlay = self.show_depth = False

        # One ZMQ context per channel, created in phys/mist/dept/norm/semt order.
        for channel in ("phys", "mist", "dept", "norm", "semt"):
            setattr(self, "_context_" + channel, zmq.Context())

        requested = self.env.config["output"]
        self._require_semantics = 'semantics' in requested
        self._require_normal = 'normal' in requested

        def connect_req(context, port_number):
            # Open a REQ socket on `context` and connect it to localhost.
            sock = context.socket(zmq.REQ)
            sock.connect("tcp://localhost:{}".format(port_number))
            return sock

        # Mist uses 5555 + gpu_count; depth, normal and semantics use the
        # three ports just below 5555 (normal/semantics only when requested).
        self.socket_mist = connect_req(self._context_mist, 5555 + gpu_count)
        self.socket_dept = connect_req(self._context_dept, 5555 - 1)
        if self._require_normal:
            self.socket_norm = connect_req(self._context_norm, 5555 - 2)
        if self._require_semantics:
            self.socket_semt = connect_req(self._context_semt, 5555 - 3)

        self.target_poses = target_poses
        self.imgs, self.depths = imgs, depths
        self.target = target
        self.semantics = semantics
        self.model = None
        self.old_topk = set()
        self.k = 5
        self.use_filler = use_filler

        self.showsz = windowsz
        self.capture_count = 0

        def blank_image():
            # Fresh (showsz, showsz, 3) uint8 buffer of zeros.
            return np.zeros((self.showsz, self.showsz, 3), dtype='uint8')

        self.show = blank_image()
        self.show_rgb = blank_image()
        self.show_semantics = blank_image()
        self.show_unfilled = blank_image()
        self.surface_normal = blank_image()
        self.semtimg_count = 0

        # The checkpoint is chosen by the exact configured resolution and
        # loaded through a DataParallel wrapper; only the inner module is kept.
        net = torch.nn.DataParallel(CompletionNet(norm=nn.BatchNorm2d, nf=64)).cuda()
        checkpoint = os.path.join(
            assets_file_dir,
            "model_{}.pth".format(self.env.config["resolution"]))
        net.load_state_dict(torch.load(checkpoint))
        filler = net.module
        filler.eval()
        self.model = filler if self.env.config["use_filler"] else None

        def gpu_zeros(channels):
            # (1, channels, showsz, showsz) zero Variable on the GPU.
            return Variable(torch.zeros(1, channels, self.showsz, self.showsz),
                            volatile=True).cuda()

        self.imgv = gpu_zeros(3)
        self.maskv = gpu_zeros(2)
        self.mean = torch.from_numpy(
            np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32))

        if not gui or self.env.config["display_ui"]:
            return
        self.renderToScreenSetup()
Beispiel #5
0
def _fill_input_buffers(data, mean, img, maskv, img_original):
    """Copy one (source, source_depth, target, ...) batch into the GPU buffers.

    Pixels whose first three ``source`` channels sum to <= 0 are treated as
    empty and are filled in place with ``mean``. The first slice along the
    last axis of ``source_depth`` (assumed to be laid out as (B, H, W, C) —
    TODO confirm) is stacked with the validity mask to form the 2-channel
    mask input. ``img``, ``maskv`` and ``img_original`` are persistent
    Variables that are overwritten via ``.data.copy_``.
    """
    source = data[0]
    source_depth = data[1]
    target = data[2]

    mask = (torch.sum(source[:, :3, :, :], 1) > 0).float().unsqueeze(1)
    # Broadcasting (1 - mask) over channels and `mean` over pixels replaces
    # the explicit repeat to a hard-coded batch x 1024 x 2048 shape.
    source[:, :3, :, :] += (1 - mask) * mean.view(1, 3, 1, 1)
    source_depth = torch.cat([source_depth[:, :, :, 0].unsqueeze(1), mask], 1)

    img.data.copy_(source)
    maskv.data.copy_(source_depth)
    img_original.data.copy_(target)


def _to_color_patches(t, patchsize, scale):
    """Regroup an RGB batch into non-overlapping ``scale`` x ``scale`` tiles.

    ``t`` must hold ``t.size(0) * 3 * patchsize * patchsize`` elements. The
    result has shape ``(N, 3, patchsize // scale, patchsize // scale,
    scale * scale)``, so ``.mean(dim=-1)`` gives per-tile channel means.
    Floor division keeps every size an int; ``/`` yields a float on
    Python 3, which ``view`` rejects.
    """
    cells = patchsize // scale
    n = t.size(0)
    return (t.view(n, 3, cells, scale, cells, scale)
            .transpose(4, 3).contiguous()
            .view(n, 3, cells, cells, -1))


def main():
    """Train the CompletionNet filler on panorama pairs (optionally with an unfiller).

    Options are read from the command line. Each step fills empty source
    pixels with a fixed mean colour, crops patches via ``crop`` and
    optimises ``comp`` with the loss selected by ``--loss``. With
    ``--unfiller``, a second network ``comp2`` takes the target image and is
    trained towards ``comp``'s output. With ``--joint``, that unfiller loss
    also back-propagates into ``comp``.

    Every 200 steps a test batch is rendered to
    ``<outf>/compare<epoch>_<i>.png`` and to TensorBoard. Every 2000 steps
    checkpoints are written to ``<outf>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataroot', required=True, help='path to dataset')
    parser.add_argument('--debug', action='store_true', help='debug mode')
    parser.add_argument('--imgsize', type=int, default=256, help='image size')
    parser.add_argument('--batchsize', type=int, default=20, help='batchsize')
    parser.add_argument('--workers', type=int, default=9, help='number of workers')
    parser.add_argument('--nepoch', type=int, default=50, help='number of epochs')
    parser.add_argument('--lr', type=float, default=2e-5, help='learning rate, default=0.002')
    parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
    parser.add_argument('--outf', type=str, default="filler_pano_pc_full", help='output folder')
    parser.add_argument('--model', type=str, default="", help='model path')
    parser.add_argument('--cepoch', type=int, default=0, help='current epoch')
    # Restrict --loss to the values handled below. Any other value used to
    # fail only after setup, with a NameError on `loss`.
    parser.add_argument('--loss', type=str, default="perceptual", help='l1 only',
                        choices=['train_init', 'l1', 'perceptual', 'color_stable', 'color_correction'])
    parser.add_argument('--init', type=str, default="iden", help='init method')
    parser.add_argument('--l1', type=float, default=0, help='add l1 loss')
    parser.add_argument('--color_coeff', type=float, default=0, help='add color match loss')
    parser.add_argument('--unfiller', action='store_true', help='debug mode')
    parser.add_argument('--joint', action='store_true', help='debug mode')
    parser.add_argument('--use_depth', action='store_true', default=False, help='debug mode')
    parser.add_argument('--zoom', type=int, default=1, help='debug mode')
    parser.add_argument('--patchsize', type=int, default=256, help='debug mode')

    # RGB fill colour for empty source pixels.
    mean = torch.from_numpy(np.array([0.57441127, 0.54226291, 0.50356019]).astype(np.float32)).clone()
    opt = parser.parse_args()
    print(opt)
    writer = SummaryWriter(opt.outf + '/runs/' + datetime.now().strftime('%B%d  %H:%M:%S'))
    try:
        os.makedirs(opt.outf)
    except OSError:
        # Best effort: the output folder may already exist.
        pass

    zoom = opt.zoom
    patchsize = opt.patchsize

    tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    mist_tf = transforms.Compose([
        transforms.ToTensor(),
    ])

    d = PairDataset(root=opt.dataroot, transform=tf, mist_transform=mist_tf)
    d_test = PairDataset(root=opt.dataroot, transform=tf, mist_transform=mist_tf, train=False)

    cudnn.benchmark = True

    dataloader = torch.utils.data.DataLoader(d, batch_size=opt.batchsize, shuffle=True, num_workers=int(opt.workers),
                                             drop_last=True, pin_memory=False)
    dataloader_test = torch.utils.data.DataLoader(d_test, batch_size=opt.batchsize, shuffle=True,
                                                  num_workers=int(opt.workers), drop_last=True, pin_memory=False)

    # Persistent full-panorama GPU buffers, refilled every step.
    img = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()
    maskv = Variable(torch.zeros(opt.batchsize, 2, 1024, 2048)).cuda()
    img_original = Variable(torch.zeros(opt.batchsize, 3, 1024, 2048)).cuda()

    comp = CompletionNet(norm=nn.BatchNorm2d, nf=64)

    current_epoch = opt.cepoch

    comp = torch.nn.DataParallel(comp).cuda()

    if opt.init == 'iden':
        comp.apply(identity_init)
    else:
        comp.apply(weights_init)

    if opt.model != '':
        comp.load_state_dict(torch.load(opt.model))

    if opt.unfiller:
        comp2 = CompletionNet(norm=nn.BatchNorm2d, nf=64)
        comp2 = torch.nn.DataParallel(comp2).cuda()
        if opt.model != '':
            # The unfiller checkpoint path is derived by replacing 'G' with 'G2'.
            comp2.load_state_dict(torch.load(opt.model.replace('G', 'G2')))
        optimizerG2 = torch.optim.Adam(comp2.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    l2 = nn.MSELoss()
    optimizerG = torch.optim.Adam(comp.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

    # No discriminator is trained in this function. These stay 0 but are
    # still printed and logged.
    errG_data = 0
    errD_data = 0

    vgg16 = models.vgg16(pretrained=False)
    vgg16.load_state_dict(torch.load('vgg16-397923af.pth'))
    feat = vgg16.features
    p = torch.nn.DataParallel(Perceptual(feat)).cuda()

    # The perceptual network is a frozen feature extractor.
    for param in p.parameters():
        param.requires_grad = False

    imgnet_mean = torch.from_numpy(np.array([0.485, 0.456, 0.406]).astype(np.float32)).clone()
    imgnet_std = torch.from_numpy(np.array([0.229, 0.224, 0.225]).astype(np.float32)).clone()

    # ImageNet normalisation images sized to the cropped batch
    # (crop is expected to return opt.batchsize * 4 patches).
    imgnet_mean_img = Variable(imgnet_mean.view(1, 3, 1, 1).repeat(opt.batchsize * 4, 1, patchsize, patchsize)).cuda()
    imgnet_std_img = Variable(imgnet_std.view(1, 3, 1, 1).repeat(opt.batchsize * 4, 1, patchsize, patchsize)).cuda()

    test_loader_enum = enumerate(dataloader_test)
    for epoch in range(current_epoch, opt.nepoch):
        for i, data in enumerate(dataloader, 0):
            optimizerG.zero_grad()
            step = i + epoch * len(dataloader)

            _fill_input_buffers(data, mean, img, maskv, img_original)
            imgc, maskvc, img_originalc = crop(img, maskv, img_original, zoom, patchsize)
            recon = comp(imgc, maskvc)

            # Perceptual features of the ground truth. They are computed by
            # the 'color_correction' loss, or lazily by the joint unfiller loss.
            org_percept = None

            if opt.loss == "train_init":
                loss = l2(recon, imgc[:, :3, :, :])
            elif opt.loss == 'l1':
                # Despite the option name, this is an MSE loss.
                loss = l2(recon, img_originalc)
            elif opt.loss == 'perceptual':
                loss = l2(p(recon), p(img_originalc).detach()) + opt.l1 * l2(recon, img_originalc)
            elif opt.loss == 'color_stable':
                # Feed each colour channel separately as a grey 3-channel image.
                loss = l2(p(recon.view(recon.size(0) * 3, 1, patchsize, patchsize).repeat(1, 3, 1, 1)),
                          p(img_originalc.view(img_originalc.size(0) * 3, 1, patchsize, patchsize).repeat(1, 3, 1, 1)).detach())
            elif opt.loss == 'color_correction':
                recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                org_percept = p((img_originalc - imgnet_mean_img) / (imgnet_std_img)).detach()
                loss = l2(recon_percept, org_percept)
                for scale in [32]:
                    # Match per-tile mean colours of output and target.
                    img_originalc_patch_mean = _to_color_patches(img_originalc, patchsize, scale).mean(dim=-1)
                    recon_patch_mean = _to_color_patches(recon, patchsize, scale).mean(dim=-1)
                    color_loss = l2(recon_patch_mean, img_originalc_patch_mean)

                    loss += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

            # Keep the graph: the unfiller losses below reuse `recon`.
            loss.backward(retain_graph=True)

            if opt.unfiller:
                optimizerG2.zero_grad()

                recon2 = comp2(img_originalc, maskvc)

                if not opt.joint:
                    recon2_percept = p((recon2 - imgnet_mean_img) / imgnet_std_img)
                    recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                    loss2 = l2(recon2_percept, recon_percept.detach())
                else:
                    recon_percept = p((recon - imgnet_mean_img) / imgnet_std_img)
                    z = Variable(torch.zeros(recon_percept.size()).cuda())
                    recon2_percept = p((recon2 - imgnet_mean_img) / imgnet_std_img)

                    # Compare the difference against zeros (presumably because
                    # older MSELoss rejects targets that require grad).
                    # recon_percept is not detached, so loss2.backward() also
                    # accumulates gradients into comp before optimizerG.step().
                    loss2 = l2(recon2_percept - recon_percept, z)

                    if org_percept is None:
                        # Previously only defined by --loss color_correction;
                        # other losses crashed here with a NameError.
                        org_percept = p((img_originalc - imgnet_mean_img) / imgnet_std_img).detach()
                    loss2 += 0.2 * l2(recon2_percept, org_percept)

                for scale in [32]:
                    # Per-tile mean colours of the (detached) filler output
                    # versus the unfiller output.
                    img_originalc_patch_mean = _to_color_patches(recon.detach(), patchsize, scale).mean(dim=-1)
                    recon2_patch_mean = _to_color_patches(recon2, patchsize, scale).mean(dim=-1)

                    if opt.joint:
                        z = Variable(torch.zeros(img_originalc_patch_mean.size()).cuda())
                        color_loss = l2(recon2_patch_mean - img_originalc_patch_mean, z)
                    else:
                        color_loss = l2(recon2_patch_mean, img_originalc_patch_mean)

                    loss2 += opt.color_coeff * color_loss

                    print("color loss %f" % color_loss.data[0])

                loss2 = loss2 * 0.3
                loss2.backward(retain_graph=True)
                print("loss2 %f" % loss2.data[0])
                optimizerG2.step()

                if i % 10 == 0:
                    writer.add_scalar('MSEloss2', loss2.data[0], step)

            if opt.loss == "train_init":
                # For 4-D parameters (presumably conv weights) whose
                # size(2) // 2 exceeds 5, zero the gradient of the first
                # size(2) // 2 entries along dim 0.
                for param in comp.parameters():
                    if len(param.size()) == 4:
                        nk = param.size()[2] // 2
                        if nk > 5:
                            param.grad[:nk, :, :, :] = 0

            optimizerG.step()

            print('[%d/%d][%d/%d] %d MSEloss: %f G_loss %f D_loss %f' % (
                epoch, opt.nepoch, i, len(dataloader), step, loss.data[0], errG_data, errD_data))

            if i % 200 == 0:
                test_i, test_data = next(test_loader_enum)
                # Restart the test iterator a few batches before it runs out.
                if test_i > len(dataloader_test) - 5:
                    test_loader_enum = enumerate(dataloader_test)

                _fill_input_buffers(test_data, mean, img, maskv, img_original)
                imgc, maskvc, img_originalc = crop(img, maskv, img_original, zoom, patchsize)
                comp.eval()
                recon = comp(imgc, maskvc)
                comp.train()

                if opt.unfiller:
                    comp2.eval()
                    recon2 = comp2(img_originalc, maskvc)
                    comp2.train()
                    visual = torch.cat([imgc.data[:, :3, :, :], recon.data, recon2.data, img_originalc.data], 3)
                else:
                    visual = torch.cat([imgc.data[:, :3, :, :], recon.data, img_originalc.data], 3)

                visual = vutils.make_grid(visual, normalize=True)
                writer.add_image('image', visual, step)
                vutils.save_image(visual, '%s/compare%d_%d.png' % (opt.outf, epoch, i), nrow=1)

            if i % 10 == 0:
                writer.add_scalar('MSEloss', loss.data[0], step)
                writer.add_scalar('G_loss', errG_data, step)
                writer.add_scalar('D_loss', errD_data, step)

            if i % 2000 == 0:
                torch.save(comp.state_dict(), '%s/compG_epoch%d_%d.pth' % (opt.outf, epoch, i))

                if opt.unfiller:
                    torch.save(comp2.state_dict(), '%s/compG2_epoch%d_%d.pth' % (opt.outf, epoch, i))