Example no. 1
    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.classifier = Classifier(opt.model)  #.cuda(device=opt.device)
        #####################
        #    Init weights
        #####################
        # self.classifier.apply(weights_init)

        print_network(self.classifier)

        self.optimizer = get_optimizer(opt, self.classifier)
        self.scheduler = get_scheduler(opt, self.optimizer)

        # load networks
        # if opt.load:
        #     pretrained_path = opt.load
        #     self.load_network(self.classifier, 'G', opt.which_epoch, pretrained_path)
        # if self.training:
        #     self.load_network(self.discriminitor, 'D', opt.which_epoch, pretrained_path)

        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

        # with open('datasets/class_weight.pkl', 'rb') as f:
        #     class_weight = pickle.load(f, encoding='bytes')
        #     class_weight = np.array(class_weight, dtype=np.float32)
        #     class_weight = torch.from_numpy(class_weight).to(opt.device)
        #     if opt.class_weight:
        #         self.criterionCE = nn.CrossEntropyLoss(weight=class_weight)
        #     else:
        self.criterionCE = nn.CrossEntropyLoss()
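
The commented-out block above shows the intended pattern for class-balanced training: load per-class weights and hand them to the loss. A minimal, self-contained sketch of that pattern (the pickle path comes from the comment; the device handling is an assumption, not the repository's code):

import pickle
import numpy as np
import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load a list of per-class weights and move it to the training device so that
# CrossEntropyLoss penalizes mistakes on rare classes more heavily.
with open('datasets/class_weight.pkl', 'rb') as f:
    class_weight = pickle.load(f, encoding='bytes')
class_weight = torch.from_numpy(np.array(class_weight, dtype=np.float32)).to(device)

criterionCE = nn.CrossEntropyLoss(weight=class_weight)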
Example no. 2
    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.classifier = Classifier()  #.cuda(device=opt.device)
        #####################
        #    Init weights
        #####################
        # self.classifier.apply(weights_init)

        print_network(self.classifier)

        self.optimizer = optim.Adam(self.classifier.parameters(),
                                    lr=opt.lr,
                                    betas=(0.95, 0.999))

        # load networks
        if opt.load:
            pretrained_path = opt.load
            self.load_network(self.classifier, 'G', opt.which_epoch,
                              pretrained_path)
            # if self.training:
            #     self.load_network(self.discriminitor, 'D', opt.which_epoch, pretrained_path)

        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)
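
Example no. 2 restores pretrained weights through a load_network helper that is not shown in the snippet. A generic stand-in for that kind of helper, assuming checkpoints are stored as plain state_dict files named '<epoch>_net_<label>.pth' (the naming convention is an assumption, not taken from the repository):

import os
import torch

def load_network(network, network_label, epoch_label, save_dir):
    # Hypothetical checkpoint filename; adjust to match how the weights were saved.
    filename = '%s_net_%s.pth' % (epoch_label, network_label)
    path = os.path.join(save_dir, filename)
    state_dict = torch.load(path, map_location='cpu')
    network.load_state_dict(state_dict)
    return network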
Example no. 3
    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.classifier = Classifier()
        #####################
        #    Init weights
        #####################
        # self.classifier.apply(weights_init)

        print_network(self.classifier)

        self.optimizer = optim.Adam(self.classifier.parameters(),
                                    lr=opt.lr,
                                    betas=(0.95, 0.999))
        self.scheduler = get_scheduler(opt, self.optimizer)

        # load networks
        # if opt.load:
        #     pretrained_path = opt.load
        #     self.load_network(self.classifier, 'G', opt.which_epoch, pretrained_path)

        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)
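
Example no. 3 wires the Adam optimizer into a get_scheduler factory. As a rough sketch of what such a factory might return (the repository's actual policy and hyperparameters may differ), a simple step-decay schedule looks like this:

from torch.optim import lr_scheduler

def get_scheduler(opt, optimizer):
    # Placeholder policy: multiply the learning rate by 0.1 every 30 epochs.
    # step_size and gamma would normally be read from opt.
    return lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)

During training the returned scheduler is stepped once per epoch with scheduler.step().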
Example no. 4
    def __init__(self, opt):
        super(Model, self).__init__()
        self.opt = opt
        self.direct_feature = DirectFeature(opt.model)
        self.feature_nums = self.direct_feature.get_feature_num()
        self.meta_embedding = MetaEmbedding(self.feature_nums, 50030)

        print_network(self.direct_feature)
        print_network(self.meta_embedding)

        # TODO: could the learning rates here be split, e.g. 0.01 for direct_feature and 0.1 for meta_embedding?
        # self.optimizer = optim.SGD(chain(self.direct_feature.parameters(), self.meta_embedding.parameters()),
        #                            lr=0.01, momentum=0.9, weight_decay=0.0005)

        self.optimizer = optim.Adam(chain(self.direct_feature.parameters(), self.meta_embedding.parameters()),
                                    lr=0.01)

        self.scheduler = get_scheduler(opt, self.optimizer)

        self.avg_meters = ExponentialMovingAverage(0.95)
        self.save_dir = os.path.join(opt.checkpoint_dir, opt.tag)

        # plain cross-entropy; per-class weights could be passed via the `weight` argument
        self.criterionCE = nn.CrossEntropyLoss()
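
The TODO in Example no. 4 asks whether direct_feature and meta_embedding could train with different learning rates. PyTorch parameter groups support exactly that; a runnable sketch with stand-in modules (the 0.01 / 0.1 values come from the TODO and are untested suggestions):

import torch.nn as nn
import torch.optim as optim

# Stand-in modules so the sketch runs on its own; in the example these would be
# self.direct_feature and self.meta_embedding.
direct_feature = nn.Linear(512, 256)
meta_embedding = nn.Linear(256, 50030)

optimizer = optim.Adam([
    {'params': direct_feature.parameters(), 'lr': 0.01},  # feature extractor
    {'params': meta_embedding.parameters(), 'lr': 0.1},   # embedding head
])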