def __init__(self):
    self.best_accuracy = 0.0
    self.adjust_learning_rate = Config.adjust_learning_rate

    # all data
    self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
    self.task_train = MiniImageNetDataset(self.data_train, Config.num_way, Config.num_shot)
    self.task_train_loader = DataLoader(self.task_train, Config.batch_size,
                                        shuffle=True, num_workers=Config.num_workers)

    # model
    self.proto_net = RunnerTool.to_cuda(Config.proto_net)
    RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))
    self.loss_ce = RunnerTool.to_cuda(nn.CrossEntropyLoss())
    self.loss_mse = RunnerTool.to_cuda(nn.MSELoss())

    # optim
    self.proto_net_optim = torch.optim.SGD(self.proto_net.parameters(), lr=Config.learning_rate,
                                           momentum=0.9, weight_decay=5e-4)

    self.test_tool = TestTool(self.proto_test, data_root=Config.data_root,
                              num_way=Config.num_way, num_shot=Config.num_shot,
                              episode_size=Config.episode_size, test_episode=Config.test_episode,
                              transform=self.task_train.transform_test)
    pass

def __init__(self):
    self.best_accuracy = 0.0

    # all data
    self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
    self.task_train = MiniImageNetDataset(self.data_train, Config.num_way, Config.num_shot)
    self.task_train_loader = DataLoader(self.task_train, Config.batch_size,
                                        shuffle=True, num_workers=Config.num_workers)

    # model
    self.proto_net = RunnerTool.to_cuda(Config.proto_net)
    RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))

    # optim
    self.proto_net_optim = torch.optim.Adam(self.proto_net.parameters(), lr=Config.learning_rate)
    self.proto_net_scheduler = StepLR(self.proto_net_optim, Config.train_epoch // 3, gamma=0.5)

    self.test_tool = TestTool(self.proto_test, data_root=Config.data_root,
                              num_way=Config.num_way, num_shot=Config.num_shot,
                              episode_size=Config.episode_size, test_episode=Config.test_episode,
                              transform=self.task_train.transform_test)
    pass

def train(self):
    Tools.print()
    Tools.print("Training...")

    for epoch in range(1, 1 + Config.train_epoch):
        self.proto_net.train()

        Tools.print()
        all_loss = 0.0
        self.adjust_learning_rate(epoch=epoch)
        Tools.print("{:6} lr:{}".format(epoch, self.proto_net_optim.param_groups[0]["lr"]))
        for task_data, task_labels, task_index in tqdm(self.task_train_loader):
            task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

            # 1 calculate features
            log_p_y = self.proto(task_data)

            # 2 loss
            loss = -(log_p_y * task_labels).sum() / task_labels.sum()
            all_loss += loss.item()

            # 3 backward
            self.proto_net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
            self.proto_net_optim.step()
            ###########################################################################
            pass

        ###########################################################################
        # print
        Tools.print("{:6} loss:{:.3f}".format(epoch, all_loss / len(self.task_train_loader)))
        ###########################################################################

        ###########################################################################
        # Val
        if epoch % Config.val_freq == 0:
            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.proto_net.eval()
            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass

    pass

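# NOTE (added sketch): self.proto() is defined elsewhere in the repository and is not shown
# in this file. The helper below is a hedged guess at the standard Prototypical Networks
# forward that would produce the log_p_y consumed by the loss above: embed every task image,
# average the support embeddings of each class into a prototype, and return the log-softmax
# over negative squared Euclidean distances from every query to every prototype. The tensor
# layout (support samples first, queries last) and the helper name are assumptions, not the
# repository implementation.
def proto_sketch(proto_net, task_data, num_way, num_shot):
    batch, num_image, c, h, w = task_data.shape  # num_image = num_way * num_shot + num_query
    z = proto_net(task_data.view(-1, c, h, w)).view(batch, num_image, -1)

    num_support = num_way * num_shot
    z_support, z_query = z[:, :num_support], z[:, num_support:]
    # prototype = mean embedding of the shots of each class
    z_proto = z_support.view(batch, num_way, num_shot, -1).mean(dim=2)

    # squared Euclidean distance from each query to each prototype: [batch, num_query, num_way]
    dists = torch.cdist(z_query, z_proto, p=2) ** 2
    return torch.log_softmax(-dists, dim=-1)
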
def __init__(self):
    self.best_accuracy = 0.0
    self.adjust_learning_rate = Config.adjust_learning_rate

    # all data
    self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
    self.task_train = MiniImageNetDataset(self.data_train, Config.num_way, Config.num_shot)
    self.task_train_loader = DataLoader(self.task_train, Config.batch_size, True,
                                        num_workers=Config.num_workers)

    # IC
    self.produce_class = ProduceClass(len(self.data_train), Config.ic_out_dim, Config.ic_ratio)
    self.produce_class.init()
    self.task_train.set_samples_class(self.produce_class.classes)
    self.task_train.set_samples_feature(self.produce_class.features)

    # model
    self.proto_net = RunnerTool.to_cuda(Config.proto_net)
    self.ic_model = RunnerTool.to_cuda(ICResNet(low_dim=Config.ic_out_dim))
    RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))
    RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

    # optim
    self.proto_net_optim = torch.optim.SGD(self.proto_net.parameters(), lr=Config.learning_rate,
                                           momentum=0.9, weight_decay=5e-4)
    self.ic_model_optim = torch.optim.SGD(self.ic_model.parameters(), lr=Config.learning_rate,
                                          momentum=0.9, weight_decay=5e-4)

    # loss
    self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

    # Eval
    self.test_tool_fsl = TestTool(self.proto_test, data_root=Config.data_root,
                                  num_way=Config.num_way, num_shot=Config.num_shot,
                                  episode_size=Config.episode_size, test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
    self.test_tool_ic = ICTestTool(feature_encoder=None, ic_model=self.ic_model,
                                   data_root=Config.data_root, batch_size=Config.batch_size,
                                   num_workers=Config.num_workers, ic_out_dim=Config.ic_out_dim)
    pass

def __init__(self):
    self.best_accuracy = 0.0

    # all data
    self.data_train = MiniImageNetDataset.get_data_all(Config.data_root)
    self.task_train = MiniImageNetDataset(self.data_train, Config.num_way, Config.num_shot)
    self.task_train_loader = DataLoader(self.task_train, Config.batch_size, True,
                                        num_workers=Config.num_workers)

    # IC
    self.produce_class = ProduceClass(len(self.data_train), Config.ic_out_dim, Config.ic_ratio)
    self.produce_class.init()

    # model
    self.proto_net = RunnerTool.to_cuda(Config.proto_net)
    self.ic_model = RunnerTool.to_cuda(Config.ic_proto_net)
    RunnerTool.to_cuda(self.proto_net.apply(RunnerTool.weights_init))
    RunnerTool.to_cuda(self.ic_model.apply(RunnerTool.weights_init))

    # optim
    self.proto_net_optim = torch.optim.Adam(self.proto_net.parameters(), lr=Config.learning_rate)
    self.ic_model_optim = torch.optim.Adam(self.ic_model.parameters(), lr=Config.learning_rate)
    self.proto_net_scheduler = StepLR(self.proto_net_optim, Config.train_epoch // 3, gamma=0.5)
    self.ic_model_scheduler = StepLR(self.ic_model_optim, Config.train_epoch // 3, gamma=0.5)

    # loss
    self.ic_loss = RunnerTool.to_cuda(nn.CrossEntropyLoss())

    # Eval
    self.test_tool_fsl = TestTool(self.proto_test, data_root=Config.data_root,
                                  num_way=Config.num_way, num_shot=Config.num_shot,
                                  episode_size=Config.episode_size, test_episode=Config.test_episode,
                                  transform=self.task_train.transform_test)
    self.test_tool_ic = ICTestTool(feature_encoder=self.proto_net, ic_model=self.ic_model,
                                   data_root=Config.data_root, batch_size=Config.batch_size,
                                   num_workers=Config.num_workers, ic_out_dim=Config.ic_out_dim)
    pass

def train(self):
    Tools.print()
    Tools.print("Training...")

    for epoch in range(1, 1 + Config.train_epoch):
        self.proto_net.train()

        Tools.print()
        all_loss = 0.0
        pn_lr = self.adjust_learning_rate(self.proto_net_optim, epoch,
                                          Config.first_epoch, Config.t_epoch, Config.learning_rate)
        Tools.print('Epoch: [{}] pn_lr={}'.format(epoch, pn_lr))
        for task_data, task_labels, task_index in tqdm(self.task_train_loader):
            task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

            # 1 calculate features
            dists = self.proto(task_data)

            # 2 loss
            if Config.is_mse:
                targets = -(task_labels - 1)
                loss = self.loss_mse(dists, targets)
            else:
                targets = torch.argmax(task_labels, dim=1) // Config.num_shot
                loss = self.loss_ce(-dists, targets)
            all_loss += loss.item()

            # 3 backward
            self.proto_net.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
            self.proto_net_optim.step()
            ###########################################################################
            pass

        ###########################################################################
        # print
        Tools.print("{:6} loss:{:.3f}".format(epoch, all_loss / len(self.task_train_loader)))
        ###########################################################################

        ###########################################################################
        # Val
        if epoch % Config.val_freq == 0:
            Tools.print()
            Tools.print("Test {} {} .......".format(epoch, Config.model_name))
            self.proto_net.eval()
            val_accuracy = self.test_tool.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass

    pass

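# NOTE (added example): a small hedged illustration of the two target constructions used in
# the loss branches above, assuming task_labels is a one-hot row per query over the
# num_way * num_shot support samples; with num_shot = 1 the support index equals the class
# index, so the MSE and CE branches address the same columns of dists.
import torch

task_labels = torch.tensor([[0., 0., 1., 0., 0.]])  # 5-way 1-shot, true class is index 2
mse_targets = -(task_labels - 1)                     # tensor([[1., 1., 0., 1., 1.]])
ce_targets = torch.argmax(task_labels, dim=1) // 1   # num_shot = 1 -> tensor([2])
# The MSE branch pulls the distance to the true prototype toward 0 and the others toward 1;
# the CE branch treats -dists as logits and maximizes the probability of class 2.
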
def train(self):
    Tools.print()
    Tools.print("Training...")

    # Init Update
    # try:
    #     self.proto_net.eval()
    #     self.ic_model.eval()
    #     Tools.print("Init label {} .......")
    #     self.produce_class.reset()
    #     with torch.no_grad():
    #         for task_data, task_labels, task_index in tqdm(self.task_train_loader):
    #             ic_labels = RunnerTool.to_cuda(task_index[:, -1])
    #             task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)
    #             log_p_y, query_features = self.proto(task_data)
    #             ic_out_logits, ic_out_l2norm = self.ic_model(query_features)
    #             self.produce_class.cal_label(ic_out_l2norm, ic_labels)
    #             pass
    #         pass
    #     Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))
    # finally:
    #     pass

    for epoch in range(Config.train_epoch):
        self.proto_net.train()
        self.ic_model.train()

        Tools.print()
        self.produce_class.reset()
        all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
        for task_data, task_labels, task_index in tqdm(self.task_train_loader):
            ic_labels = RunnerTool.to_cuda(task_index[:, -1])
            task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

            ###########################################################################
            # 1 calculate features
            log_p_y, query_features = self.proto(task_data)
            ic_out_logits, ic_out_l2norm = self.ic_model(query_features)

            # 2
            ic_targets = self.produce_class.get_label(ic_labels)
            self.produce_class.cal_label(ic_out_l2norm, ic_labels)

            # 3 loss
            loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum() * Config.loss_fsl_ratio
            loss_ic = self.ic_loss(ic_out_logits, ic_targets) * Config.loss_ic_ratio
            loss = loss_fsl + loss_ic
            all_loss += loss.item()
            all_loss_fsl += loss_fsl.item()
            all_loss_ic += loss_ic.item()

            # 4 backward
            self.proto_net.zero_grad()
            self.ic_model.zero_grad()
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.proto_net.parameters(), 0.5)
            # torch.nn.utils.clip_grad_norm_(self.ic_model.parameters(), 0.5)
            self.proto_net_optim.step()
            self.ic_model_optim.step()
            ###########################################################################
            pass

        ###########################################################################
        # print
        Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} lr:{}".format(
            epoch + 1, all_loss / len(self.task_train_loader),
            all_loss_fsl / len(self.task_train_loader),
            all_loss_ic / len(self.task_train_loader), self.proto_net_scheduler.get_last_lr()))
        Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))

        self.proto_net_scheduler.step()
        self.ic_model_scheduler.step()
        ###########################################################################

        ###########################################################################
        # Val
        if epoch % Config.val_freq == 0:
            self.proto_net.eval()
            self.ic_model.eval()
            self.test_tool_ic.val(epoch=epoch)
            val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                torch.save(self.ic_model.state_dict(), Config.ic_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass

    pass

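# NOTE (added sketch): ProduceClass is defined elsewhere; the class below is only a hedged
# guess at the idea behind the get_label()/cal_label() calls above. Every sample keeps a
# pseudo class label; cal_label() greedily re-assigns each sample to its highest-scoring
# class that still has capacity (capacity = n_sample / out_dim * ratio, which keeps the
# pseudo classes roughly balanced), and get_label() looks up the previous epoch's assignment
# as the cross-entropy target. Names and details are hypothetical, not the repository code.
import numpy as np
import torch

class ProduceClassSketch(object):
    def __init__(self, n_sample, out_dim, ratio=1.0):
        self.n_sample, self.out_dim = n_sample, out_dim
        self.capacity = int(n_sample / out_dim * ratio)
        self.classes = np.zeros(n_sample, dtype=np.int64)       # assignment used as targets
        self.classes_next = np.zeros(n_sample, dtype=np.int64)  # assignment being built
        self.class_num = np.zeros(out_dim, dtype=np.int64)      # per-class fill level
        self.count = 0  # how many samples changed class this epoch

    def reset(self):
        self.classes = self.classes_next.copy()
        self.classes_next[:] = 0
        self.class_num[:] = 0
        self.count = 0

    def cal_label(self, ic_out_l2norm, indexes):
        # rank classes per sample and take the best one that is not yet full
        ranked = torch.argsort(ic_out_l2norm, dim=1, descending=True).cpu().numpy()
        for i, sample_index in enumerate(indexes.cpu().numpy()):
            for c in ranked[i]:
                if self.class_num[c] < self.capacity:
                    self.class_num[c] += 1
                    self.count += int(self.classes[sample_index] != c)
                    self.classes_next[sample_index] = c
                    break

    def get_label(self, indexes):
        # targets for the IC cross-entropy come from the previous epoch's assignment
        return torch.as_tensor(self.classes[indexes.cpu().numpy()], device=indexes.device)
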
def train(self):
    Tools.print()
    Tools.print("Training...")

    # Init Update
    if False:
        self.ic_model.eval()

        Tools.print("Init label {} .......")
        self.produce_class.reset()
        for task_data, task_labels, task_index, task_ok in tqdm(self.task_train_loader):
            ic_labels = RunnerTool.to_cuda(task_index[:, -1])
            task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)
            ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])
            self.produce_class.cal_label(ic_out_l2norm, ic_labels)
            pass
        Tools.print("Epoch: {}/{}".format(self.produce_class.count, self.produce_class.count_2))
        pass

    for epoch in range(1, 1 + Config.train_epoch):
        self.proto_net.train()
        self.ic_model.train()

        Tools.print()
        pn_lr = self.adjust_learning_rate(self.proto_net_optim, epoch,
                                          Config.first_epoch, Config.t_epoch, Config.learning_rate)
        ic_lr = self.adjust_learning_rate(self.ic_model_optim, epoch,
                                          Config.first_epoch, Config.t_epoch, Config.learning_rate)
        Tools.print('Epoch: [{}] pn_lr={} ic_lr={}'.format(epoch, pn_lr, ic_lr))

        self.produce_class.reset()
        Tools.print(self.task_train.classes)
        is_ok_total, is_ok_acc = 0, 0
        all_loss, all_loss_fsl, all_loss_ic = 0.0, 0.0, 0.0
        for task_data, task_labels, task_index, task_ok in tqdm(self.task_train_loader):
            ic_labels = RunnerTool.to_cuda(task_index[:, -1])
            task_data, task_labels = RunnerTool.to_cuda(task_data), RunnerTool.to_cuda(task_labels)

            ###########################################################################
            # 1 calculate features
            log_p_y = self.proto(task_data)
            ic_out_logits, ic_out_l2norm = self.ic_model(task_data[:, -1])

            # 2
            ic_targets = self.produce_class.get_label(ic_labels)
            self.produce_class.cal_label(ic_out_l2norm, ic_labels)

            # 3 loss
            loss_fsl = -(log_p_y * task_labels).sum() / task_labels.sum()
            loss_ic = self.ic_loss(ic_out_logits, ic_targets)
            loss = loss_fsl * Config.loss_fsl_ratio + loss_ic * Config.loss_ic_ratio
            all_loss += loss.item()
            all_loss_fsl += loss_fsl.item()
            all_loss_ic += loss_ic.item()

            # 4 backward
            if Config.train_ic:
                self.ic_model.zero_grad()
                loss_ic.backward()
                self.ic_model_optim.step()
                pass

            self.proto_net.zero_grad()
            loss_fsl.backward()
            self.proto_net_optim.step()

            # is ok
            is_ok_acc += torch.sum(torch.cat(task_ok))
            is_ok_total += torch.prod(torch.tensor(torch.cat(task_ok).shape))
            ###########################################################################
            pass

        ###########################################################################
        # print
        Tools.print("{:6} loss:{:.3f} fsl:{:.3f} ic:{:.3f} ok:{:.3f}({}/{})".format(
            epoch, all_loss / len(self.task_train_loader),
            all_loss_fsl / len(self.task_train_loader),
            all_loss_ic / len(self.task_train_loader),
            int(is_ok_acc) / int(is_ok_total), is_ok_acc, is_ok_total))
        Tools.print("Train: [{}] {}/{}".format(epoch, self.produce_class.count, self.produce_class.count_2))
        ###########################################################################

        ###########################################################################
        # Val
        if epoch % Config.val_freq == 0:
            self.proto_net.eval()
            self.ic_model.eval()
            self.test_tool_ic.val(epoch=epoch)
            val_accuracy = self.test_tool_fsl.val(episode=epoch, is_print=True)
            if val_accuracy > self.best_accuracy:
                self.best_accuracy = val_accuracy
                torch.save(self.proto_net.state_dict(), Config.pn_dir)
                torch.save(self.ic_model.state_dict(), Config.ic_dir)
                Tools.print("Save networks for epoch: {}".format(epoch))
                pass
            pass
        ###########################################################################
        pass

    pass

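# NOTE (added sketch): Config.adjust_learning_rate is not shown in this file. The helper
# below is a hedged guess at its behavior based on the (optimizer, epoch, first_epoch,
# t_epoch, base lr) call sites above: hold the base rate for first_epoch epochs, then apply
# cosine decay with period t_epoch and write the new rate into the optimizer. The name,
# schedule shape, and eta_min are assumptions; the real helper may differ.
import math

def adjust_learning_rate_sketch(optimizer, epoch, first_epoch, t_epoch, learning_rate, eta_min=1e-5):
    if epoch <= first_epoch:
        lr = learning_rate
    else:
        t = (epoch - first_epoch) % t_epoch
        lr = eta_min + (learning_rate - eta_min) * (1 + math.cos(math.pi * t / t_epoch)) / 2
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
    return lr
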
def __init__(self):
    self.proto_net = RunnerTool.to_cuda(Config.proto_net)
    pass