def load_network_stageI(self):
    """Build the stage-I generator and discriminator.

    Both networks are instantiated, passed through ``weights_init`` and
    printed. For each of the module-level settings ``NET_G`` and ``NET_D``
    that is not the empty string, the state dict stored at that path is
    loaded into the matching network. When ``CUDA`` is truthy both networks
    are moved to the GPU.

    Returns:
        tuple: ``(netG, netD)``.
    """
    from model import STAGE1_G, STAGE1_D

    def _restore(net, checkpoint):
        # An empty string means no checkpoint was configured for this net.
        if checkpoint == '':
            return
        # Per the torch.load docs, a map_location that returns the storage
        # unchanged keeps the deserialised tensors on the CPU.
        weights = torch.load(checkpoint,
                             map_location=lambda storage, loc: storage)
        net.load_state_dict(weights)
        print('Load from: ', checkpoint)

    netG = STAGE1_G()
    netG.apply(weights_init)
    print(netG)

    netD = STAGE1_D()
    netD.apply(weights_init)
    print(netD)

    _restore(netG, NET_G)
    _restore(netD, NET_D)

    if CUDA:
        netG.cuda()
        netD.cuda()

    return netG, netD
示例#2
0
    def load_network_stageI(self):
        """Build the stage-I networks and work out the epoch to resume from.

        The generator and discriminator are instantiated, passed through
        ``weights_init`` and printed. If ``cfg.NET_G`` / ``cfg.NET_D`` is not
        the empty string, the state dict at that path is loaded into the
        matching network. Both networks are moved to the GPU when
        ``cfg.CUDA`` is truthy.

        Returns:
            tuple: ``(netG, netD, start_epoch)``. ``start_epoch`` is 0 when
            no generator checkpoint is configured; otherwise it is one more
            than the integer that follows the last ``_`` in the checkpoint's
            file name (extension removed).

        Raises:
            ValueError: if ``cfg.NET_G`` is set but the trailing part of its
                file name is not an integer.
        """
        import os
        from model import STAGE1_G, STAGE1_D

        netG = STAGE1_G()
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D()
        netD.apply(weights_init)
        print(netD)

        if cfg.NET_G != '':
            state_dict = \
                torch.load(cfg.NET_G,
                           map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_G)
        if cfg.NET_D != '':
            state_dict = \
                torch.load(cfg.NET_D,
                           map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', cfg.NET_D)
        if cfg.CUDA:
            netG.cuda()
            netD.cuda()

        start_epoch = 0
        if cfg.NET_G != '':
            # Parse only the file name, never the directories: slicing the
            # full path between the last '_' and the last '.' picks up
            # separators from directory names and, for a file without an
            # extension, silently drops the final digit of the epoch.
            stem = os.path.splitext(os.path.basename(cfg.NET_G))[0]
            start_epoch = int(stem.rsplit('_', 1)[-1]) + 1

        return netG, netD, start_epoch
    def load_network_stageI(self):
        """Build the stage-I networks and restore any configured checkpoints.

        The generator is built without arguments and the discriminator with
        ``self.new_arch``; both are passed through ``weights_init``. For each
        of ``cfg.NET_G``, ``cfg.NET_D`` and ``cfg.CTModel`` that is not the
        empty string, the state dict at that path is loaded into the
        generator, the discriminator and ``self.CTallmodel`` respectively.
        When ``cfg.CUDA`` is truthy the generator and discriminator are moved
        to the GPU (``self.CTallmodel`` is not moved here).

        Returns:
            tuple: ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D

        def _read_state(path):
            # Per the torch.load docs, a map_location that returns the
            # storage unchanged keeps the deserialised tensors on the CPU.
            return torch.load(path, map_location=lambda storage, loc: storage)

        netG = STAGE1_G()
        netG.apply(weights_init)

        netD = STAGE1_D(self.new_arch)
        netD.apply(weights_init)

        g_path = cfg.NET_G
        if g_path != '':
            g_weights = _read_state(g_path)
            netG.load_state_dict(g_weights)
            print('Load from: ', g_path)

        d_path = cfg.NET_D
        if d_path != '':
            d_weights = _read_state(d_path)
            netD.load_state_dict(d_weights)
            print('Load from: ', d_path)

        ct_path = cfg.CTModel
        if ct_path != '':
            # self.CTallmodel is only touched once its checkpoint is read.
            ct_weights = _read_state(ct_path)
            self.CTallmodel.load_state_dict(ct_weights)
            print('Load CT Model From: ', ct_path)

        if cfg.CUDA:
            netG.cuda()
            netD.cuda()

        return netG, netD
示例#4
0
    def load_network_stageI(self, text_dim, gf_dim, condition_dim, z_dim, df_dim):
        """Build the stage-I generator and discriminator from explicit sizes.

        Args:
            text_dim, gf_dim, condition_dim, z_dim: forwarded, in this order
                and followed by ``self.cuda``, to the ``STAGE1_G``
                constructor.
            df_dim: forwarded together with ``condition_dim`` to the
                ``STAGE1_D`` constructor.

        Both networks are passed through ``weights_init`` and printed. If
        ``self.net_g`` / ``self.net_d`` is not the empty string, the state
        dict at that path is loaded into the matching network. Both networks
        are moved to the GPU when ``self.cuda`` is truthy.

        Returns:
            tuple: ``(netG, netD)``.
        """
        from model import STAGE1_G, STAGE1_D

        netG = STAGE1_G(text_dim, gf_dim, condition_dim, z_dim, self.cuda)
        netG.apply(weights_init)
        print(netG)
        netD = STAGE1_D(df_dim, condition_dim)
        netD.apply(weights_init)
        print(netD)

        if self.net_g != '':
            # Load with the same map_location as the discriminator below.
            # Without it, torch.load tries to restore tensors to the device
            # they were saved from, so a GPU-saved generator checkpoint
            # fails on a CPU-only host even when self.cuda is false. The
            # weights still end up on the GPU via .cuda() further down.
            state_dict = torch.load(self.net_g,
                                    map_location=lambda storage, loc: storage)
            netG.load_state_dict(state_dict)
            print('Load from: ', self.net_g)
        if self.net_d != '':
            state_dict = torch.load(self.net_d,
                                    map_location=lambda storage, loc: storage)
            netD.load_state_dict(state_dict)
            print('Load from: ', self.net_d)
        if self.cuda:
            netG.cuda()
            netD.cuda()
        return netG, netD
示例#5
0
    def load_network_stageI(self):
        """Build every network used in stage I and restore configured weights.

        Creates the generator, discriminator, text encoder, attention
        decoder, image encoder and encoding discriminator, plus a pretrained
        VGG-16 classifier whose parameters have ``requires_grad`` set to
        False. ``weights_init`` is applied to the generator, discriminator
        and image encoder only. For each of ``cfg.NET_G``, ``cfg.NET_D``,
        ``cfg.ENCODER``, ``cfg.DECODER`` and ``cfg.IMAGE_ENCODER`` that is not
        the empty string, the state dict at that path is loaded into the
        matching network. When ``cfg.CUDA`` is truthy all seven modules are
        moved to the GPU.

        Returns:
            tuple: ``(netG, netD, encoder, decoder, image_encoder, e_disc,
            clf_model)``.
        """
        from model import STAGE1_G, STAGE1_D
        from model import EncoderRNN, LuongAttnDecoderRNN
        from model import STAGE1_ImageEncoder, EncodingDiscriminator

        # Construction order is kept as-is so that any random initialisation
        # consumes the RNG in the same sequence.
        netG = STAGE1_G()
        netG.apply(weights_init)

        netD = STAGE1_D()
        netD.apply(weights_init)

        emb_dim = 300
        encoder = EncoderRNN(emb_dim, self.txt_emb, 1, dropout=0.0)

        attn_model = 'general'
        vocab_size = len(self.txt_dico.id2word)
        decoder = LuongAttnDecoderRNN(
            attn_model, emb_dim, vocab_size, self.txt_emb,
            n_layers=1, dropout=0.0)

        image_encoder = STAGE1_ImageEncoder()
        image_encoder.apply(weights_init)

        e_disc = EncodingDiscriminator(emb_dim)

        # (cfg setting name, network that receives the checkpoint). The
        # setting is looked up only when its turn comes, one at a time.
        restorable = (
            ('NET_G', netG),
            ('NET_D', netD),
            ('ENCODER', encoder),
            ('DECODER', decoder),
            ('IMAGE_ENCODER', image_encoder),
        )
        for setting, module in restorable:
            path = getattr(cfg, setting)
            if path == '':
                continue
            # Per the torch.load docs, a map_location that returns the
            # storage unchanged keeps the deserialised tensors on the CPU.
            state_dict = torch.load(
                path, map_location=lambda storage, loc: storage)
            module.load_state_dict(state_dict)
            print('Load from: ', path)

        # Classification model with frozen weights (an AlexNet alternative
        # was tried and left disabled in the original).
        clf_model = models.vgg16(pretrained=True)
        for weight in clf_model.parameters():
            weight.requires_grad = False

        if cfg.CUDA:
            for module in (netG, netD, encoder, decoder,
                           image_encoder, e_disc, clf_model):
                module.cuda()

        # NOTE: a disabled experiment used to sit here that replaced
        # clf_model.classifier with a small Linear/LeakyReLU/Dropout head
        # (200 classes) and created an SGD optimizer to fine-tune it.

        return netG, netD, encoder, decoder, image_encoder, e_disc, clf_model
示例#6
0
        z = self.reparameterize(mu, logvar)

        #z_p = self.decode_lin(z).view(-1, self.nef, 4, 4)
        #img = self.decode_img(z_p)

        #z = self.encode_img(x).view(-1, 480 * 4 * 4)
        img = self.decode(z)
        # mu = None
        # logvar = None

        return img, mu, logvar


# Networks: the VAE defined above and the stage-I discriminator.
model = VAE()
netD = STAGE1_D()

# Move both networks to the GPU before the optimizers are created.
if args.cuda:
    model.cuda()
    netD.cuda()

# Only VAE parameters that require gradients are handed to its optimizer.
model_params = (p for p in model.parameters() if p.requires_grad)
optimizer = optim.Adam(model_params, lr=1e-3)

# The discriminator gets its own Adam optimizer with a smaller step size.
optd = optim.Adam(netD.parameters(), lr=1e-4)


# Reconstruction + KL divergence losses summed over all elements and batch
def loss_function(recon_x, x, mu, logvar):
    #pdb.set_trace()
    #BCE = F.binary_cross_entropy(recon_x, x, size_average=False)
    BCE = F.smooth_l1_loss(recon_x, F.sigmoid(x), size_average=False)