if "bn" not in name and "output" not in name: if weights_vector is None: weights_vector = param.flatten() else: weights_vector = torch.cat((weights_vector, param.flatten()), 0) ewcData = maybe_cuda(ewcData, use_cuda=True) loss = (lambd / 2) * torch.dot(ewcData[1], (weights_vector - ewcData[0])**2) return loss def weights_init(m): classname = m.__class__.__name__ if classname.find('Conv') != -1: m.weight.data.normal_(0.0, 0.02) elif classname.find('BatchNorm') != -1: m.weight.data.normal_(1.0, 0.02) m.bias.data.fill_(0) if __name__ == "__main__": from models.mobilenet import MyMobilenetV1 model = MyMobilenetV1(pretrained=True) replace_bn_with_brn(model, "net") ewcData, synData = create_syn_data(model) extract_weights(model, ewcData[0])
# do not remove this line start_time = time.time() # Create the dataset object dataset = CORE50( root='/home/admin/ssd_data/cvpr_competition/cvpr_competition_data/', scenario=scenario, preload=False) preproc = preprocess_imgs # Get the fixed test set full_valdidset = dataset.get_full_valid_set() # Model setup model = MyMobilenetV1(pretrained=True, latent_layer_num=latent_layer_num) replace_bn_with_brn(model, momentum=init_update_rate, r_d_max_inc_step=inc_step, max_r_max=max_r_max, max_d_max=max_d_max) model.saved_weights = {} model.past_j = {i: 0 for i in range(50)} model.cur_j = {i: 0 for i in range(50)} if reg_lambda != 0: ewcData, synData = create_syn_data(model) # Optimizer setup optimizer = torch.optim.SGD(model.parameters(), lr=init_lr, momentum=momentum,