def distillation_loss(self, logits, reprs, edges):
    """Compute the three graph-distillation loss terms.

    Args:
        logits: Indexable collection of logits; it is looked up with
            ``self.to_idx`` (the distillation target) and with every
            entry of ``self.from_idx`` (the distillation sources).
        reprs: Representations, indexed the same way as ``logits``.
        edges: Distillation edge weights. Row ``i`` is paired with the
            ``i``-th entry of ``self.from_idx``; dim 1 is averaged for
            the regularizer (presumably the batch axis -- TODO confirm).

    Returns:
        Tuple ``(loss_reg, loss_logit, loss_repr)``: the regularization
        term, the weighted logit distance summed over all sources, and
        the weighted representation distance summed over all sources.
    """
    # Regularizer: squared deviation of the dim-1 mean of the edges from
    # the graph prior, summed and scaled by ``self.gd_reg``.
    mean_edges = edges.mean(1)
    prior_gap = mean_edges - self.gd_prior
    loss_reg = prior_gap.pow(2).sum() * self.gd_reg

    loss_logit = 0
    loss_repr = 0
    for pos, src in enumerate(self.from_idx):
        # Per-source weight: the edge row plus its graph prior.
        weight = edges[pos] + self.gd_prior[pos]

        logit_term = self.w_losses[0] * utils.distance_metric(
            logits[self.to_idx], logits[src], self.metric, weight)
        loss_logit += logit_term

        repr_term = self.w_losses[1] * utils.distance_metric(
            reprs[self.to_idx], reprs[src], self.metric, weight)
        loss_repr += repr_term

    return loss_reg, loss_logit, loss_repr
예제 #2
0
    def eval_target_images(self, netT, ims_np, opt_params, vis_epochs=10):
        """Run the evaluation loop for ``ims_np`` with a given ``netT``.

        A new latent-code module is created for the images, ``netT`` is
        stored on the instance, and ``self.eval_epoch`` is called once per
        epoch.

        Args:
            netT: Network stored as ``self.netT`` and used by the
                evaluation epochs.
            ims_np: Array with a 4-element ``shape``; the last two
                dimensions must be equal (presumably laid out as
                (count, channels, height, width) -- TODO confirm).
            opt_params: Options object; ``opt_params.epochs`` sets the
                number of epochs and the object is forwarded to
                ``self.eval_epoch``.
            vis_epochs: ``self.visualize`` is called on every epoch that
                is a multiple of this value.

        Raises:
            AssertionError: If the last two dimensions of ``ims_np``
                differ.
        """
        num_images, channels, side, side_y = ims_np.shape
        assert (side == side_y), "Input must be square!"

        # One latent code per image, kept on the GPU.
        self.netZ = utils.latent_codes(self.net_params.nz, num_images)
        self.netZ.cuda()
        self.netT = netT
        self.dist = utils.distance_metric(
            side, channels, self.net_params.force_l2)

        for epoch in range(opt_params.epochs):
            error = self.eval_epoch(epoch, ims_np, opt_params)
            print("NAM Eval Epoch: %d Error: %f" % (epoch, error))
            if epoch % vis_epochs != 0:
                continue
            self.visualize(epoch, ims_np, "nam_eval_ims")
예제 #3
0
def predict(args, model, data_loader):
    """Predict a class index for every query sample of every episode.

    Each batch yielded by ``data_loader`` is treated as one episode: the
    first ``args.N_way * args.N_shot`` samples are the support set and the
    remaining samples are the queries.

    Args:
        args: Namespace providing ``distance``, ``device``, ``N_way``,
            ``N_shot`` and ``hallu_m``.
        model: Embedding network. It is called on the support and query
            batches and also provides ``hallucinate(prototypes, m)``.
        data_loader: Sized iterable of ``(data, target)`` episodes.
            ``target`` is not needed for prediction and is ignored.

    Returns:
        2-D ``numpy`` array with one row per episode. Column 0 is the
        episode index; the remaining columns hold, for each query of that
        episode, the index of the best-scoring prototype along dim 1 of
        the distance output.
    """
    distance = distance_metric(args.distance, model)
    n_support = args.N_way * args.N_shot

    prediction_results = []
    with torch.no_grad():
        # Each batch represents one episode (support data + query data).
        for episode, (data, _target) in enumerate(data_loader):
            data = data.to(args.device)

            # Support samples come first, queries after them.
            support_input = data[:n_support]
            query_input = data[n_support:]

            support_latent = model(support_input)
            query_latent = model(query_input)

            # Group support embeddings per class; this assumes the support
            # samples are ordered class by class (N_shot samples each).
            proto = support_latent.reshape(args.N_way, args.N_shot, -1)
            proto = proto.permute(1, 0, 2)  # (shot, way, z)

            # Augment the real shots with hallucinated ones generated from
            # the per-class mean, then average everything per class.
            proto_hallu = model.hallucinate(proto.mean(dim=0), args.hallu_m)
            proto_aug = torch.cat((proto, proto_hallu), dim=0)  # (shot+m, way, z)
            proto_aug = proto_aug.mean(dim=0)  # (way, z)

            # Highest score along dim 1 selects the predicted class.
            logits = distance(query_latent, proto_aug)
            _, indices = torch.topk(logits, k=1, dim=1)
            prediction_results.append(indices.view(-1).cpu().numpy())

            print(f"[{episode}/{len(data_loader)}]", end="  \r")

    prediction_results = np.array(prediction_results)
    # Prepend the episode index as the first column.
    index = np.arange(len(prediction_results)).reshape(-1, 1)
    return np.concatenate((index, prediction_results), axis=1)
예제 #4
0
    def fit_to_target_domain(self, ims_np, opt_params, vis_epochs=10):
        """Train the encoder ``netZ`` and transformer ``netT`` on ``ims_np``.

        Both networks are created, moved to the GPU and trained for
        ``opt_params.epochs`` epochs. After every epoch their state dicts
        are written to ``nam_nets/netZ.pth`` and ``nam_nets/netT.pth``.

        Args:
            ims_np: Array with a 4-element ``shape``; the last two
                dimensions must be equal (presumably laid out as
                (count, channels, height, width) -- TODO confirm).
            opt_params: Options object; ``opt_params.epochs`` sets the
                number of epochs and the object is forwarded to
                ``self.train_epoch``.
            vis_epochs: ``self.visualize`` is called on every epoch that
                is a multiple of this value.

        Raises:
            AssertionError: If the last two dimensions of ``ims_np``
                differ.
        """
        import os  # local import: only needed for the checkpoint directory

        n, nc, sz, sz_y = ims_np.shape
        assert (sz == sz_y), "Input must be square!"

        self.netZ = utils.encoderVAE(
            sz, self.net_params.nz, self.net_params.ndf, nc)
        self.netZ.cuda()

        self.netT = utils.transformer(
            sz, self.net_params.ngf, self.input_shape[2], nc)
        self.netT.cuda()

        self.dist = utils.distance_metric(sz, nc, self.net_params.force_l2)

        # torch.save does not create missing parent directories; make sure
        # the checkpoint folder exists up front so that the first epoch's
        # work is not lost to a save error.
        os.makedirs('nam_nets', exist_ok=True)

        for epoch in range(opt_params.epochs):
            er, kl_er, rec_er = self.train_epoch(epoch, ims_np, opt_params)
            print("NAM Epoch: %d Error: %f KL: %f rec: %f" % (epoch, er, kl_er, rec_er))
            # Checkpoint after every epoch (files are overwritten).
            torch.save(self.netZ.state_dict(), 'nam_nets/netZ.pth')
            torch.save(self.netT.state_dict(), 'nam_nets/netT.pth')
            if epoch % vis_epochs == 0:
                self.visualize(epoch, ims_np, "nam_train_ims")
예제 #5
0
    def __init__(self, glo_params, vid_params, rn):
        """Set up the latent-code network, the generator and the loss.

        Args:
            glo_params: Options object; ``nz`` sizes the latent codes and
                the fixed noise, ``force_l2`` is forwarded to the distance
                metric when the ``VGG`` branch is taken.
            vid_params: Options object; ``n`` is passed to the latent-code
                network constructor.
            rn: Value stored as ``self.rn`` and passed to
                ``self.load_weights`` when resuming.

        The names ``load``, ``counter``, ``VGG`` and ``LAP`` are not
        parameters; they are read from the enclosing (module) scope.
        """
        self.rn = rn
        self.lr = 0.01

        # Latent-code network: initialise its weights, then move it to GPU.
        latent_net = model_video_orig._netZ(glo_params.nz, vid_params.n)
        latent_net.apply(model_video_orig.weights_init)
        latent_net.cuda()
        self.netZ = latent_net

        self.data_loader = data_loader.DataLoader()

        # Generator network, initialised the same way.
        generator = model_video_orig.netG_new(glo_params.nz)
        generator.apply(model_video_orig.weights_init)
        generator.cuda()

        # Wrap the generator for multi-GPU use when more than one device
        # is visible.
        gpu_count = torch.cuda.device_count()
        if gpu_count > 1:
            print("Using " + str(gpu_count) + " GPU's")
            for device_id in range(gpu_count):
                print(torch.cuda.get_device_name(device_id))
            generator = nn.DataParallel(generator)
        self.netG = generator

        if load:  # resume from a saved load point
            self.load_weights(counter, self.rn)

        # Fixed standard-normal noise kept on the GPU (used by the
        # visualisation code, per the original comment).
        self.vis_n = 100
        self.fixed_noise = torch.FloatTensor(
            self.vis_n, glo_params.nz).normal_(0, 1).cuda()

        self.nag_params = glo_params
        self.vid_params = vid_params
        self.blockResnext = 101

        # Exactly one loss object is created, under a different attribute
        # name depending on the selected mode.
        if VGG:
            self.dist_frame = utils.distance_metric(64, 3, glo_params.force_l2)
        elif LAP:
            self.lap_loss = lap.LapLoss(max_levels=3)
        else:
            self.dist = perceptual_loss_video._resnext_videoDistance(
                self.blockResnext)
예제 #6
0
    def __init__(self, glo_params, image_params, rn):
        """Set up the latent-code network, the generator and the distance.

        Args:
            glo_params: Options object; ``nz`` sizes the latent codes and
                the fixed noise, ``do_bn`` is forwarded to the generator
                and ``force_l2`` to the distance metric.
            image_params: Options object; ``n`` is passed to the
                latent-code network, ``sz[0]`` and ``nc`` to the generator
                and the distance metric.
            rn: Value stored as ``self.rn``.
        """
        self.rn = rn
        self.glo_params = glo_params
        self.image_params = image_params

        # Latent-code network: initialise its weights, then move it to GPU.
        latent_net = model._netZ(glo_params.nz, image_params.n)
        latent_net.apply(model.weights_init)
        latent_net.cuda()
        self.netZ = latent_net

        # Generator network, initialised the same way.
        image_size = image_params.sz[0]
        generator = model._netG(
            glo_params.nz, image_size, image_params.nc, glo_params.do_bn)
        generator.apply(model.weights_init)
        generator.cuda()
        self.netG = generator

        # Fixed standard-normal noise kept on the GPU.
        self.vis_n = 64
        self.fixed_noise = torch.FloatTensor(
            self.vis_n, glo_params.nz).normal_(0, 1).cuda()

        self.dist = utils.distance_metric(
            image_size, image_params.nc, glo_params.force_l2)
예제 #7
0
    valid_set = MiniImageNet_Dataset('../hw4_data/val/', valid_trans)
    valid_sampler = CategoriesSampler(valid_set.label,
                                      n_batch=args.n_batch,
                                      n_ways=args.valid_way,
                                      n_shot=args.shot + args.query)
    valid_loader = DataLoader(valid_set,
                              batch_sampler=valid_sampler,
                              num_workers=6,
                              worker_init_fn=worker_init_fn)

    # model
    model = Conv4_Hallu().to(args.device)
    model.train()

    # distance F
    distance = distance_metric(args.distance, model)

    # optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    #lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
    steplr_after = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=len(train_loader) * args.epochs, eta_min=1e-5)
    lr_scheduler = WarmupScheduler(optimizer,
                                   multiplier=1,
                                   total_epoch=args.warmup_epochs *
                                   len(train_loader),
                                   after_scheduler=steplr_after)

    best_acc = 0
    train_loss, train_acc, valid_loss, valid_acc = [
        Averager() for i in range(4)