Пример #1
0
class IndependantSampler(Sampler):
    """Draw independent one-hot samples from a fixed categorical distribution.

    The (misspelled) class name is kept as-is so existing callers keep working.
    """

    def __init__(self, p):
        """Build the one-hot sampling distribution from probabilities ``p``.

        Args:
            p: Category probabilities, passed positionally (i.e. as ``probs``)
                to ``OneHotCategorical``; the last dimension indexes the
                categories. Because ``sample`` permutes exactly three axes,
                ``p`` is expected to be 2-D (one row of probabilities per
                independent variable) -- NOTE(review): confirm with callers.
        """
        super().__init__()
        self.sampler = OneHotCategorical(p)
        # Keep the raw probabilities around for callers that inspect them.
        self.p = p

    def sample(self, data):
        """Draw ``data.size(0)`` independent one-hot samples.

        Args:
            data: Tensor whose first dimension gives the number of samples to
                draw; its contents are not used.

        Returns:
            The drawn samples with their last two axes swapped. For a 2-D
            ``p`` of shape ``(rows, categories)`` the raw draw has shape
            ``(n, rows, categories)`` and the result ``(n, categories, rows)``.
        """
        n = data.size(0)
        # ``Distribution.sample_n`` is deprecated and warns on every call;
        # ``sample((n,))`` is the documented equivalent (same shape, same RNG use).
        return self.sampler.sample((n,)).permute(0, 2, 1)
Пример #2
0
            'params': ae.dec.conv3.parameters(),
            'lr': args.lr * 5
        },
    ],
                                 lr=args.lr)

    optimizer = torch.optim.Adam(ae.parameters(), lr=args.lr)
    loss_func = nn.MSELoss()
    t1 = time.time()
    ploss1 = ploss2 = ploss3 = ploss4 = ploss5 = torch.FloatTensor(0).cuda()
    for epoch in range(args.epoch):
        for step in range(args.num_step_per_epoch):
            # Generate codes randomly
            one_hot = OneHotCategorical([1. / args.num_class] * args.num_class)
            x = torch.randn([args.batch_size, args.num_class
                             ]) + one_hot.sample_n(args.batch_size)  # logits
            x = x.cuda()
            prob_gt = nn.functional.softmax(x, dim=1)  # prob, ground truth

            # forward
            if args.mode == "BD":
                feats1, feats2 = ae(x)
            elif args.mode == "SE":
                feats1, small_feats1, feats2 = ae(
                    x
                )  # feats1: feats from encoder. small_feats1: feats from small encoder. feats2: feats from encoder.

            # code loss: cross entropy
            if args.mode == "BD":
                logits1 = feats1[-1]
                logprob_1 = nn.functional.log_softmax(logits1, dim=1)
Пример #3
0
   for clip in os.path.basename(args.e2).split("_"):
     if clip[0] == "E" and "S" in clip:
       num1 = clip.split("E")[1].split("S")[0]
       num2 = clip.split("S")[1]
       if num1.isdigit() and num2.isdigit():
         previous_epoch = int(num1)
         previous_step  = int(num2)
 
 # Optimization
 t1 = time.time()
 for epoch in range(previous_epoch, args.num_epoch):
   for step, (img, label) in enumerate(train_loader):
     ae.train()
     # Generate codes randomly
     if args.use_pseudo_code:
       onehot_label = one_hot.sample_n(args.batch_size)
       x = torch.randn([args.batch_size, args.num_class]) * (np.random.rand() * 5.0 + 2.0) + onehot_label * np.random.randint(args.end, args.begin) # logits
       x = x.cuda() / args.Temp
       label = onehot_label.data.numpy().argmax(axis=1)
       label = torch.from_numpy(label).long()
     else:
       x = ae.be(img.cuda()) / args.Temp
     prob_gt = F.softmax(x, dim=1) # prob, ground truth
     label = label.cuda()
     
     if args.adv_train == 3:
       # update decoder
       imgrec = []; imgrec_DT = []; hardloss_dec = []; trainacc_dec = []; ave_imgrec = 0
       for di in range(1, args.num_dec+1):
         dec = eval("ae.d" + str(di)); optimizer = optimizer_dec[di-1]; ema = ema_dec[di-1]
         dec.zero_grad()
Пример #4
0
            ae.train()
            imgrec_all = []
            logits_all = []
            imgrec_DT_all = []
            hardloss_dec_all = []
            trainacc_dec_all = []
            actimax_loss_print = []

            if args.input == "pseudo_image":  # use artificially created data as input
                if args.lw_msgan or args.lw_msgan_feat:
                    half_bs = int(args.batch_size / 2)
                    random_z1 = torch.randn(half_bs, args.num_z).cuda()
                    random_z2 = torch.randn(half_bs, args.num_z).cuda()
                    x = torch.cat([random_z1, random_z2], dim=0)
                    if args.use_condition:
                        onehot_label = one_hot.sample_n(half_bs).view(
                            [half_bs, args.num_class]).cuda()
                        label_concat = torch.cat([onehot_label, onehot_label],
                                                 dim=0)
                        label = label_concat.argmax(dim=1).detach()
                        x = torch.cat([x, label_concat], dim=1).detach()
                        # label_noise = torch.randn(args.batch_size, args.num_class).cuda() * args.begin
                        # label = label_noise.argmax(dim=1).detach()
                        # for i in range(args.batch_size):
                        # label_noise[i, label[i]] += 5
                        # x = torch.cat([x, label_noise], dim=1).detach()
                else:
                    x = torch.randn(args.batch_size, args.num_z).cuda()
                    if args.use_condition:
                        onehot_label = one_hot.sample_n(args.batch_size).view(
                            [args.batch_size, args.num_class]).cuda()
                        label = onehot_label.argmax(dim=1).detach()