Example #1
    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses_latent = compute_d_loss(
                nets, args, x_real, y_org, y_trg, z_trg=z_trg, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            d_loss, d_losses_ref = compute_d_loss(
                nets, args, x_real, y_org, y_trg, x_ref=x_ref, masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            g_loss, g_losses_latent = compute_g_loss(
                nets, args, x_real, y_org, y_trg, z_trgs=[z_trg, z_trg2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            g_loss, g_losses_ref = compute_g_loss(
                nets, args, x_real, y_org, y_trg, x_refs=[x_ref, x_ref2], masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network, nets_ema.mapping_network, beta=0.999)
            moving_average(nets.style_encoder, nets_ema.style_encoder, beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i+1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (elapsed, i+1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([d_losses_latent, d_losses_ref, g_losses_latent, g_losses_ref],
                                        ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join(['%s: [%.4f]' % (key, value) for key, value in all_losses.items()])
                print(log)
                wandb.log({"Losses":all_losses})

            # generate images for debugging
            if (i+1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema, args, inputs=inputs_val, step=i+1)

            # save model checkpoints
            if (i+1) % args.save_every == 0:
                self._save_checkpoint(step=i+1)

            # compute FID and LPIPS if necessary
            if (i+1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i+1, mode='latent')
                calculate_metrics(nets_ema, args, i+1, mode='reference')
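The three moving_average calls above maintain an EMA copy (nets_ema) of the generator, mapping network, and style encoder, which is what gets used for sampling and evaluation. The helper itself is not part of the snippet; a minimal sketch consistent with the call signature used here, assuming both arguments are PyTorch modules with identical parameter layouts:

import torch

def moving_average(model, model_ema, beta=0.999):
    # ema_param <- beta * ema_param + (1 - beta) * param
    with torch.no_grad():
        for param, param_ema in zip(model.parameters(), model_ema.parameters()):
            param_ema.copy_(torch.lerp(param, param_ema, beta))

With beta=0.999 the EMA weights trail the live weights by roughly a thousand iterations, which smooths optimization noise out of the sampled images.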
Example #2
    def train(self, loaders):
        args = self.args
        nets = self.nets
        #nets_ema = self.nets_ema
        optims = self.optims

        # resume training if necessary
        #if args.resume_iter > 0:
        #    self._load_checkpoint( str(e) + '_' + str(args.resume_iter) )
        """ define the fetcher for dataloading """
        fetcher_tr = InputFetcher(loaders.train, 'train')
        fetcher_val = InputFetcher(loaders.val, 'val')

        print('Start training...')
        start_time = time.time()

        for e in range(args.epoch):
            for i in range(args.resume_iter, len(fetcher_tr)):  # args.total_iters
                """ get input from training and validation from fetcher """
                inputs = next(fetcher_tr)
                inputs_val = next(fetcher_val)
                gt_land, gt, gt_mask, prior = inputs.gt_land, inputs.gt, inputs.gt_mask, inputs.prior
                gt_land = gt_land.detach()
                gt = gt.detach()
                gt_mask = gt_mask.detach()
                prior = prior.detach()
                #gt.shape ...: (batch, sync_t, c, h, w)
                #prior.shape: (batch, sync_t, c*2, h, w)

                gt_land = gt_land.flatten(0, 1)  # (batch*sync_t, c*3, h, w)
                gt = gt.flatten(0, 1)
                gt_mask = gt_mask.flatten(0, 1)
                prior = prior.flatten(0, 1)  # (batch*sync_t, c*2, h, w)
                #utils.save_image(  torch.cat( (gt[:,:3], gt[:,3:]), dim=0 ) , './sample_gt.jpg')
                # utils.save_image(prior.view(5,2,3,256,256).view(10,3,256,256), './sample_prior.jpg')
                """ train the discriminator """
                # d_losses.key(): real, fake, reg
                d_loss, d_losses = compute_d_loss(nets, args, gt_land, gt,
                                                  gt_mask, prior)
                self._reset_grad()
                d_loss.backward()
                optims.discriminator.step()
                """ train the generator """
                # g_losses.key(): adv, recon
                g_loss, g_losses = compute_g_loss(nets, args, gt_land, gt,
                                                  gt_mask, prior)
                self._reset_grad()
                g_loss.backward()
                optims.generator.step()
                optims.style_encoder.step()  # style encoder update

                # print out log info
                if (i + 1) % args.print_every == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                    log = "Elapsed time [%s], Iteration [%i/%i], " % (
                        elapsed, i + 1, len(fetcher_tr))
                    all_losses = dict()
                    for loss, prefix in zip([d_losses, g_losses], ['D', 'G']):
                        for key, value in loss.items():
                            all_losses[prefix + key] = value
                    log += ' '.join([
                        '%s: [%.4f]' % (key, value)
                        for key, value in all_losses.items()
                    ])
                    print(log)

                    for k, v in all_losses.items():
                        self.summary_writer.add_scalar(k, v, e * len(fetcher_tr) + i)  # global step = epoch * steps_per_epoch + i

                # save model checkpoints
                if (i + 1) % args.save_every == 0:
                    self._save_checkpoint(step=str(e) + '_' + str(i + 1))

                # generate images for debugging
                with torch.no_grad():
                    if (i + 1) % args.sample_every == 0:
                        os.makedirs(args.sample_dir, exist_ok=True)
                        utils.debug_image(nets,
                                          args,
                                          inputs=inputs_val,
                                          step=str(e) + '_' + str(i + 1))
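Every variant clears gradients through self._reset_grad() before each backward pass, since the generator, mapping network, style encoder, and discriminator optimizers all touch the same graph. A plausible sketch, assuming self.optims is a dict-like container (e.g. a Munch) of torch.optim optimizers:

    def _reset_grad(self):
        # clear the gradients held by every optimizer so losses do not accumulate across steps
        for optim in self.optims.values():
            optim.zero_grad()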
Example #3
    def train(self, loaders):
        # run on the chosen GPU when Paddle is compiled with CUDA support, otherwise fall back to CPU
        if paddle.fluid.is_compiled_with_cuda():
            place = paddle.fluid.CUDAPlace(self.args.whichgpu)
        else:
            place = paddle.fluid.CPUPlace()
        with fluid.dygraph.guard(place):
            args = self.args
            nets = self.nets
            nets_ema = self.nets_ema
            optims = self.optims

            # fetch random validation images for debugging
            fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim,
                                   'train')
            fetcher_val = InputFetcher(loaders.val, None, args.latent_dim,
                                       'val')
            inputs_val = next(fetcher_val)

            # resume training if necessary
            if args.resume_iter > 0:
                self._load_checkpoint(args.resume_iter)

            # remember the initial value of ds weight
            initial_lambda_ds = args.lambda_ds

            print('Start training...')
            start_time = time.time()
            for i in range(args.resume_iter, args.total_iters):
                # fetch images and labels
                inputs = next(fetcher)
                x_real, y_org = inputs.x_src, inputs.y_src
                x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
                z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2
                x_real = fluid.dygraph.to_variable(x_real.astype('float32'))
                y_org = fluid.dygraph.to_variable(y_org)
                x_ref = fluid.dygraph.to_variable(x_ref.astype('float32'))
                x_ref2 = fluid.dygraph.to_variable(x_ref2.astype('float32'))
                y_trg = fluid.dygraph.to_variable(y_trg)
                z_trg = fluid.dygraph.to_variable(z_trg.astype('float32'))
                z_trg2 = fluid.dygraph.to_variable(z_trg2.astype('float32'))

                masks = nets.fan.get_heatmap(
                    x_real) if args.w_hpf > 0 else None

                # train the discriminator
                # print('train the discriminator')
                d_loss, d_losses_latent = compute_d_loss(nets,
                                                         args,
                                                         x_real,
                                                         y_org,
                                                         y_trg,
                                                         z_trg=z_trg,
                                                         masks=masks)
                self._reset_grad()
                d_loss_avg = fluid.layers.mean(d_loss)
                d_loss_avg.backward()
                optims.discriminator.minimize(d_loss_avg)

                d_loss, d_losses_ref = compute_d_loss(nets,
                                                      args,
                                                      x_real,
                                                      y_org,
                                                      y_trg,
                                                      x_ref=x_ref,
                                                      masks=masks)
                self._reset_grad()
                d_loss_avg = fluid.layers.mean(d_loss)
                d_loss_avg.backward()
                optims.discriminator.minimize(d_loss_avg)

                # train the generator
                # print('train the generator 1st')
                g_loss, g_losses_latent = compute_g_loss(
                    nets,
                    args,
                    x_real,
                    y_org,
                    y_trg,
                    z_trgs=[z_trg, z_trg2],
                    masks=masks)
                self._reset_grad()
                g_loss_avg = fluid.layers.mean(g_loss)
                g_loss_avg.backward()  # stuck here with 1.8.x cpu version
                optims.generator.minimize(g_loss_avg)
                optims.mapping_network.minimize(g_loss_avg)
                optims.style_encoder.minimize(g_loss_avg)

                # print('train the generator 2nd')

                g_loss, g_losses_ref = compute_g_loss(nets,
                                                      args,
                                                      x_real,
                                                      y_org,
                                                      y_trg,
                                                      x_refs=[x_ref, x_ref2],
                                                      masks=masks)
                self._reset_grad()
                g_loss_avg = fluid.layers.mean(g_loss)
                g_loss_avg.backward()
                optims.generator.minimize(g_loss_avg)

                # compute moving average of network parameters
                # print('compute moving average of network parameters')
                moving_average(nets.generator, nets_ema.generator, beta=0.999)
                moving_average(nets.mapping_network,
                               nets_ema.mapping_network,
                               beta=0.999)
                moving_average(nets.style_encoder,
                               nets_ema.style_encoder,
                               beta=0.999)

                # print('finish compute moving average of network parameters')

                # decay weight for diversity sensitive loss
                if args.lambda_ds > 0:
                    args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

                # print out log info
                if (i + 1) % args.print_every == 0:
                    elapsed = time.time() - start_time
                    elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                    log = "Elapsed time [%s], Iteration [%i/%i], " % (
                        elapsed, i + 1, args.total_iters)
                    all_losses = dict()
                    for loss, prefix in zip([
                            d_losses_latent, d_losses_ref, g_losses_latent,
                            g_losses_ref
                    ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                        for key, value in loss.items():
                            all_losses[prefix + key] = value
                    all_losses['G/lambda_ds'] = args.lambda_ds
                    log += ' '.join([
                        '%s: [%.4f]' % (key, value)
                        for key, value in all_losses.items()
                    ])
                    print(log)

                # generate images for debugging
                if (i + 1) % args.sample_every == 0:
                    os.makedirs(args.sample_dir, exist_ok=True)
                    utils.debug_image(nets_ema,
                                      args,
                                      inputs=inputs_val,
                                      step=i + 1)

                # save model checkpoints
                if (i + 1) % args.save_every == 0:
                    self._save_checkpoint(step=i + 1)
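Note the Paddle-specific differences in Example #3: inputs are wrapped with fluid.dygraph.to_variable, the loss is reduced with fluid.layers.mean, and optimizer.minimize replaces the .step() calls of the PyTorch variants. The diversity-sensitive weight decay itself is framework-independent: every iteration subtracts initial_lambda_ds / args.ds_iter, a linear ramp from the initial weight down to zero over ds_iter iterations. An equivalent closed-form helper (illustrative only, not part of the repository, and ignoring that the in-loop version decrements after each iteration's losses are computed):

def lambda_ds_at(i, initial_lambda_ds, ds_iter):
    # linear decay: full weight at i = 0, zero for i >= ds_iter
    return max(0.0, initial_lambda_ds * (1.0 - i / ds_iter))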
Example #4
    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        # fetch random validation images for debugging
        fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim,
                               'train')
        fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        inputs_val = next(fetcher_val)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, y_org = inputs.x_src, inputs.y_src
            x_ref, x_ref2, y_trg = inputs.x_ref, inputs.x_ref2, inputs.y_ref
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses_latent = compute_d_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trg=z_trg,
                                                     masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            d_loss, d_losses_ref = compute_d_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_ref=x_ref,
                                                  masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            g_loss, g_losses_latent = compute_g_loss(nets,
                                                     args,
                                                     x_real,
                                                     y_org,
                                                     y_trg,
                                                     z_trgs=[z_trg, z_trg2],
                                                     masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            g_loss, g_losses_ref = compute_g_loss(nets,
                                                  args,
                                                  x_real,
                                                  y_org,
                                                  y_trg,
                                                  x_refs=[x_ref, x_ref2],
                                                  masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network,
                           nets_ema.mapping_network,
                           beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([
                        d_losses_latent, d_losses_ref, g_losses_latent,
                        g_losses_ref
                ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)

            # generate images for debugging
            if (i + 1) % args.sample_every == 0:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')

            import adaiw  # imported inside the loop in the original; Python caches modules, so this could equally be hoisted to the top of the file
            if (i + 1) % args.print_learned == 0:
                alphas_list = []
                for module in self.nets.generator.decode:
                    if isinstance(module.norm1, adaiw.AdaIN):
                        alphas_list.append([
                            ('norm_1_white', module.norm1.alpha_white_param,
                             'norm_1_color', module.norm1.alpha_color_param),
                            ('norm_2_white', module.norm2.alpha_white_param,
                             'norm_2_color', module.norm2.alpha_color_param)
                        ])

                txt_f = os.path.join(args.notes_path, "alphas.txt")
                with open(txt_f, "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for item in alphas_list:
                        f.write("{}\n".format(item))

            if (i + 1) % args.print_std == 0:
                with open(os.path.join(args.notes_path, "std.txt"), "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    for j, module in enumerate(self.nets.generator.decode):
                        print([
                            module.std_b4_norm_1.item(),
                            module.std_b4_norm_2.item(),
                            module.std_b4_join.item(),
                            module.std_b4_output.item()
                        ],
                              file=f)

            if (i + 1) % args.print_sqrt_error == 0:
                with open(os.path.join(args.notes_path, "sqrt_errors.txt"),
                          "a+") as f:
                    f.write("{} iterations: \n".format(i + 1))
                    all_errors = []
                    for j, module in enumerate(self.nets.generator.decode):
                        errors = []
                        for norm in [module.norm1, module.norm2]:
                            if isinstance(norm, adaiw.AdaIN) and isinstance(
                                    norm.normalizer,
                                    adaiw.normalizer.Whitening):
                                errors.append(norm.normalizer.last_error)
                        all_errors.append(tuple(errors))
                    print(all_errors, file=f)

            if (i + 1) % args.print_color == 0:
                with torch.no_grad():
                    with open(os.path.join(args.notes_path, "last_color.txt"),
                              "a+") as f:
                        f.write("{} iterations: \n".format(i + 1))
                        for _, module in enumerate(self.nets.generator.decode):
                            if hasattr(module.norm1, 'last_injected_stat'):
                                diag1 = module.norm1.last_injected_stat.diagonal()
                                tri1 = module.norm1.last_injected_stat.tril(-1)
                                print("Mean Diag:",
                                      diag1[0].mean().item(),
                                      "Std Diag:",
                                      diag1[0].std().item(),
                                      file=f)
                                print("Mean Tril:",
                                      tri1[0].mean().item(),
                                      "Std Tril:",
                                      tri1[0].std().item(),
                                      file=f)
                                diag2 = module.norm2.last_injected_stat.diagonal()
                                tri2 = module.norm2.last_injected_stat.tril(-1)
                                print("Mean Diag:",
                                      diag2[0].mean().item(),
                                      "Std Diag:",
                                      diag2[0].std().item(),
                                      file=f)
                                print("Mean Tril:",
                                      tri2[0].mean().item(),
                                      "Std Tril:",
                                      tri2[0].std().item(),
                                      file=f)
                                print(file=f)
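Example #4 adds periodic dumps of learned AdaIN/whitening statistics (adaiw) to text files under args.notes_path on top of the standard StarGAN v2 loop. The checkpoint helpers it relies on are not shown in any of the snippets; a minimal sketch of step-indexed checkpointing for the nets / nets_ema / optims containers, assuming they are dict-like (Munch), that os and torch are imported at module level, and that args.checkpoint_dir is a hypothetical setting:

    def _save_checkpoint(self, step):
        # bundle every network, its EMA copy, and every optimizer state into one file per step
        os.makedirs(self.args.checkpoint_dir, exist_ok=True)
        state = {
            'nets': {name: net.state_dict() for name, net in self.nets.items()},
            'nets_ema': {name: net.state_dict() for name, net in self.nets_ema.items()},
            'optims': {name: opt.state_dict() for name, opt in self.optims.items()},
        }
        torch.save(state, os.path.join(self.args.checkpoint_dir, '{}_ckpt.pt'.format(step)))

    def _load_checkpoint(self, step):
        state = torch.load(os.path.join(self.args.checkpoint_dir, '{}_ckpt.pt'.format(step)),
                           map_location='cpu')
        for name, net in self.nets.items():
            net.load_state_dict(state['nets'][name])
        for name, net in self.nets_ema.items():
            net.load_state_dict(state['nets_ema'][name])
        for name, opt in self.optims.items():
            opt.load_state_dict(state['optims'][name])

Passing step as a string (as Example #2 does with 'epoch_iteration') or as an integer both work with this layout.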
Example #5
    def train(self, loaders):
        args = self.args
        nets = self.nets
        nets_ema = self.nets_ema
        optims = self.optims

        dbg = 0  # set to 1 to force the periodic print/sample/save branches on the first iteration and stop right after it (see the end of the loop)

        # fetch random validation images for debugging
        #fetcher = InputFetcher(loaders.src, loaders.ref, args.latent_dim, 'train')
        #fetcher_val = InputFetcher(loaders.val, None, args.latent_dim, 'val')
        #inputs_val = next(fetcher_val)

        fc2_loader, test_loader = loaders
        fetcher = FC2Fetcher(fc2_loader, None, args.latent_dim)
        inputs_val = next(fetcher)
        s_trg_path = os.getcwd() + "/s_trg_out"
        self.grid = None

        if not os.path.exists(s_trg_path):
            os.makedirs(s_trg_path)

        # resume training if necessary
        if args.resume_iter > 0:
            self._load_checkpoint(args.resume_iter)

        # remember the initial value of ds weight
        initial_lambda_ds = args.lambda_ds

        print('Start training...')
        start_time = time.time()
        for i in range(args.resume_iter, args.total_iters):
            # fetch images and labels
            inputs = next(fetcher)
            x_real, x_real2, y_org = inputs.x_src, inputs.x2_src, inputs.y_src
            x_ref, y_trg = inputs.x_ref, inputs.y_ref
            flow, mask = inputs.flow, inputs.mask
            z_trg, z_trg2 = inputs.z_trg, inputs.z_trg2

            masks = nets.fan.get_heatmap(x_real) if args.w_hpf > 0 else None

            # train the discriminator
            d_loss, d_losses_latent = self.compute_d_loss(nets,
                                                          args,
                                                          x_real,
                                                          y_org,
                                                          y_trg,
                                                          z_trg=z_trg,
                                                          masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            d_loss, d_losses_ref = self.compute_d_loss(nets,
                                                       args,
                                                       x_real,
                                                       y_org,
                                                       y_trg,
                                                       x_ref=x_ref,
                                                       masks=masks)
            self._reset_grad()
            d_loss.backward()
            optims.discriminator.step()

            # train the generator
            g_loss, g_losses_latent, s_trg_lat = self.compute_g_loss(
                nets,
                args,
                x_real,
                x_real2,
                flow,
                mask,
                y_org,
                y_trg,
                z_trgs=[z_trg, z_trg2],
                masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()
            optims.mapping_network.step()
            optims.style_encoder.step()

            g_loss, g_losses_ref, s_trg_ref = self.compute_g_loss(
                nets,
                args,
                x_real,
                x_real2,
                flow,
                mask,
                y_org,
                y_trg,
                x_refs=[x_ref, None],
                masks=masks)
            self._reset_grad()
            g_loss.backward()
            optims.generator.step()

            # compute moving average of network parameters
            moving_average(nets.generator, nets_ema.generator, beta=0.999)
            moving_average(nets.mapping_network,
                           nets_ema.mapping_network,
                           beta=0.999)
            moving_average(nets.style_encoder,
                           nets_ema.style_encoder,
                           beta=0.999)

            # decay weight for diversity sensitive loss
            if args.lambda_ds > 0:
                args.lambda_ds -= (initial_lambda_ds / args.ds_iter)

            # print out log info
            if (i + 1) % args.print_every == 0 + dbg:
                elapsed = time.time() - start_time
                elapsed = str(datetime.timedelta(seconds=elapsed))[:-7]
                log = "Elapsed time [%s], Iteration [%i/%i], " % (
                    elapsed, i + 1, args.total_iters)
                all_losses = dict()
                for loss, prefix in zip([
                        d_losses_latent, d_losses_ref, g_losses_latent,
                        g_losses_ref
                ], ['D/latent_', 'D/ref_', 'G/latent_', 'G/ref_']):
                    for key, value in loss.items():
                        all_losses[prefix + key] = value
                all_losses['G/lambda_ds'] = args.lambda_ds
                log += ' '.join([
                    '%s: [%.4f]' % (key, value)
                    for key, value in all_losses.items()
                ])
                print(log)
                with open('losses.txt', 'a') as log_file:
                    log_file.write(log + '\n')

            # generate images for debugging
            if (i + 1) % args.sample_every == 0 + dbg:
                os.makedirs(args.sample_dir, exist_ok=True)
                utils.debug_image(nets_ema,
                                  args,
                                  inputs=inputs_val,
                                  step=i + 1)

            # save model checkpoints
            if (i + 1) % args.save_every == 0 + dbg:
                self._save_checkpoint(step=i + 1)

            # compute FID and LPIPS if necessary
            if (i + 1) % args.eval_every == 0:
                calculate_metrics(nets_ema, args, i + 1, mode='latent')
                calculate_metrics(nets_ema, args, i + 1, mode='reference')

            # save sintel debug
            if (i + 1) % args.sample_every == 0 + dbg:
                self.debugSintel(nets, args, i + 1, test_loader)
                self.debugSintel(nets, args, i + 1, test_loader,
                                 inputs_val.x_ref)

            if (i + 1) % 100 == 0 + dbg:
                # detach to CPU before converting: .numpy() fails on CUDA tensors or tensors still in the autograd graph
                s_trg_out = torch.cat((s_trg_lat, s_trg_ref), 0).detach().cpu().numpy()
                np.save(os.path.join(s_trg_path, "s_trg_" + str(i + 1) + ".npy"), s_trg_out)

            if dbg == 1:
                # the original used an undefined name (`blah`) as a crude crash-to-stop after one debug iteration
                raise RuntimeError('dbg run finished: stopping after the first iteration')
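All five examples pull batches with next(fetcher) and read attribute-style fields such as x_src, y_src, x_ref, x_ref2, y_ref, z_trg and z_trg2. A sketch of what such an InputFetcher might look like for the StarGAN v2-style examples (#1, #3, #4), assuming the source loader yields (image, label) pairs, the reference loader yields (x_ref, x_ref2, y_ref) triples, and a Munch serves as the attribute container; the real 'val' mode, Example #2's fetcher (which also supports len()), and Example #5's FC2Fetcher differ in the fields they attach:

import torch
from munch import Munch

class InputFetcher:
    """Endless iterator over one or two DataLoaders that also attaches random latent codes."""
    def __init__(self, loader, loader_ref=None, latent_dim=16, mode='train'):
        self.loader = loader
        self.loader_ref = loader_ref
        self.latent_dim = latent_dim
        self.mode = mode
        self.iter = iter(loader)
        self.iter_ref = iter(loader_ref) if loader_ref is not None else None

    def _fetch(self, name):
        # restart the wrapped DataLoader whenever it is exhausted
        try:
            return next(getattr(self, name))
        except StopIteration:
            source = self.loader if name == 'iter' else self.loader_ref
            setattr(self, name, iter(source))
            return next(getattr(self, name))

    def __next__(self):
        x_src, y_src = self._fetch('iter')
        inputs = Munch(x_src=x_src, y_src=y_src,
                       z_trg=torch.randn(x_src.size(0), self.latent_dim),
                       z_trg2=torch.randn(x_src.size(0), self.latent_dim))
        if self.iter_ref is not None:
            x_ref, x_ref2, y_ref = self._fetch('iter_ref')
            inputs.update(x_ref=x_ref, x_ref2=x_ref2, y_ref=y_ref)
        return inputs

Wrapping the loaders this way is what lets the training loops above run for a fixed number of iterations (args.total_iters) instead of iterating epoch by epoch.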