def main():
    """Meta-train the Omniglot relation network and checkpoint the best weights.

    Each training episode samples a CLASS_NUM-way task from the meta-training
    characters (SAMPLE_NUM_PER_CLASS support and BATCH_NUM_PER_CLASS query
    images per class, all rotated by one random multiple of 90 degrees), scores
    every (query, class) pair with the relation network and minimises the MSE
    against one-hot targets.  Every 5000 episodes the networks are evaluated on
    TEST_EPISODE meta-test tasks; whenever the accuracy improves, both state
    dicts are written under ``./models/`` (created if missing).
    """
    import os  # only needed to create the checkpoint directory

    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)
    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(), lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=100000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(), lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim, step_size=100000, gamma=0.5)

    # Save under ./models/, the directory checkpoints are loaded from elsewhere
    # in this code.
    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = "./models/omniglot_feature_encoder_" + checkpoint_suffix
    relation_network_path = "./models/omniglot_relation_network_" + checkpoint_suffix

    def relation_scores(support_images, query_images, num_queries):
        """Score every query image against every class.

        Support features are averaged over the SAMPLE_NUM_PER_CLASS shots of
        each class, concatenated channel-wise with each query's features and
        passed through the relation network.  Returns a tensor viewed as
        (-1, CLASS_NUM), i.e. (num_queries, CLASS_NUM) when ``query_images``
        holds ``num_queries`` images.
        """
        support = feature_encoder(support_images.cuda(GPU))
        # assumes the encoder emits FEATURE_DIM x 5 x 5 maps — TODO confirm if the encoder changes
        support = support.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS, FEATURE_DIM, 5, 5)
        support = torch.sum(support, 1) / SAMPLE_NUM_PER_CLASS  # mean over shots
        queries = feature_encoder(query_images.cuda(GPU))
        support_ext = support.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
        query_ext = torch.transpose(queries.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
        pairs = torch.cat((support_ext, query_ext), 2).view(-1, FEATURE_DIM * 2, 5, 5)
        return relation_network(pairs).view(-1, CLASS_NUM)

    mse = nn.MSELoss().cuda(GPU)  # stateless; build once instead of every episode
    num_train_queries = BATCH_NUM_PER_CLASS * CLASS_NUM
    num_test_queries = SAMPLE_NUM_PER_CLASS * CLASS_NUM

    # Step 3: build graph
    print("Training...")
    last_accuracy = 0.0
    for episode in range(EPISODE):
        # init dataset
        # sample_dataloader yields the support images queries are compared with;
        # batch_dataloader yields the query images trained on.
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                               BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                               split="train", shuffle=False, rotation=degrees)
        batch_dataloader = tg.get_data_loader(task, num_per_class=BATCH_NUM_PER_CLASS,
                                              split="test", shuffle=True, rotation=degrees)

        # sample data (DataLoader iterators have no .next() method in current PyTorch)
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        relations = relation_scores(samples, batches, num_train_queries)
        one_hot_labels = torch.zeros(num_train_queries, CLASS_NUM).scatter_(
            1, batch_labels.view(-1, 1), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        feature_encoder_optim.step()
        relation_network_optim.step()
        # Advance the LR schedules after the optimizer step (the order PyTorch
        # requires); the LR used in episode e is still LEARNING_RATE * 0.5 ** (e // 100000).
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:
            # test
            print("Testing...")
            total_rewards = 0
            # no_grad only skips building the autograd graph; forward outputs are unchanged.
            with torch.no_grad():
                for _ in range(TEST_EPISODE):
                    degrees = random.choice([0, 90, 180, 270])
                    task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                           SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
                    sample_dataloader = tg.get_data_loader(
                        task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
                        shuffle=False, rotation=degrees)
                    test_dataloader = tg.get_data_loader(
                        task, num_per_class=SAMPLE_NUM_PER_CLASS, split="test",
                        shuffle=True, rotation=degrees)
                    sample_images, _ = next(iter(sample_dataloader))
                    test_images, test_labels = next(iter(test_dataloader))

                    relations = relation_scores(sample_images, test_images, num_test_queries)
                    _, predict_labels = torch.max(relations, 1)
                    total_rewards += (predict_labels == test_labels.cuda(GPU)).sum().item()

            test_accuracy = total_rewards / num_test_queries / TEST_EPISODE
            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks
                os.makedirs(os.path.dirname(feature_encoder_path), exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(), relation_network_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy
def main():
    """Evaluate saved Omniglot relation-network weights on the meta-test split.

    Loads the encoder / relation-network checkpoints from ``./models/`` when
    they exist, then runs EPISODE rounds of TEST_EPISODE random CLASS_NUM-way,
    SAMPLE_NUM_PER_CLASS-shot tasks.  For each round it prints the mean
    accuracy and the confidence half-width ``h`` returned by
    ``mean_confidence_interval``; at the end it prints the accuracy averaged
    over rounds.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    _, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)
    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = "./models/omniglot_feature_encoder_" + checkpoint_suffix
    relation_network_path = "./models/omniglot_relation_network_" + checkpoint_suffix
    # Load to CPU first; load_state_dict copies into the GPU parameters, so this
    # works regardless of which device the checkpoint was saved from.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path, map_location="cpu"))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path, map_location="cpu"))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    for episode in range(EPISODE):
        # test
        print("Testing...")
        accuracies = []
        with torch.no_grad():  # evaluation only: skip building the autograd graph
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                       split="train", shuffle=False,
                                                       rotation=degrees)
                test_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                     split="test", shuffle=True,
                                                     rotation=degrees)
                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                # calculate features
                sample_features = feature_encoder(sample_images.cuda(GPU))
                # Pool the shots of each class into one prototype (identity for 1-shot).
                sample_features = sample_features.view(
                    CLASS_NUM, SAMPLE_NUM_PER_CLASS, *sample_features.shape[1:]).sum(1)
                test_features = feature_encoder(test_images.cuda(GPU))

                # calculate relations: pair every query with every class prototype
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                test_features_ext = torch.transpose(
                    test_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
                relation_pairs = torch.cat((sample_features_ext, test_features_ext),
                                           2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
                _, predict_labels = torch.max(relations, 1)

                # Compare on the predictions' device (GPU), not a hard-coded cuda:0.
                correct = (predict_labels == test_labels.to(predict_labels.device)).sum().item()
                accuracies.append(correct / num_queries)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print("test accuracy:", test_accuracy, "h:", h)
        total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)
def test(feature_encoder, relation_network):
    """Evaluate the given networks on TEST_EPISODE meta-test tasks and print the accuracy.

    Each task is CLASS_NUM-way with SAMPLE_NUM_PER_CLASS support and
    SAMPLE_NUM_PER_CLASS query images per class.  Rotation is drawn from
    ``[0]`` only, so images are not rotated.  Inputs are moved to the global
    ``device``; the networks are assumed to already live there.

    Args:
        feature_encoder: embedding network applied to support and query images.
        relation_network: network scoring concatenated (support, query) features.
    """
    _, metatest_character_folders = tg.omniglot_character_folders()
    # test
    print("Testing...")
    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_rewards = 0
    with torch.no_grad():  # evaluation only: no autograd graph needed
        for _ in range(TEST_EPISODE):
            # Rotation augmentation disabled; random.choice is kept so the RNG
            # stream consumed here stays the same.
            degrees = random.choice([0])
            task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                   SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
            sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                   split="train", shuffle=False,
                                                   rotation=degrees)
            test_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                 split="test", shuffle=True, rotation=degrees)
            sample_images, _ = next(iter(sample_dataloader))
            test_images, test_labels = next(iter(test_dataloader))

            # calculate features; shots of each class are summed into one prototype
            sample_features = feature_encoder(sample_images.to(device))
            sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                                   FEATURE_DIM, 5, 5)
            sample_features = torch.sum(sample_features, 1)
            test_features = feature_encoder(test_images.to(device))

            # calculate relations: pair every query with every class prototype
            sample_features_ext = sample_features.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
            test_features_ext = torch.transpose(
                test_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
            relation_pairs = torch.cat((sample_features_ext, test_features_ext),
                                       2).view(-1, FEATURE_DIM * 2, 5, 5)
            relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
            _, predict_labels = torch.max(relations, 1)

            total_rewards += (predict_labels == test_labels.to(predict_labels.device)).sum().item()

    test_accuracy = total_rewards / num_queries / TEST_EPISODE
    print("test accuracy:", test_accuracy)
def main():
    """Meta-train an Omniglot relation network configured by the global ``args``.

    Tasks are ``args.class_num``-way with ``args.training_samples_per_class``
    support images per class.  Training tasks use
    ``args.support_set_samples_per_class`` query images per class.
    Validation runs every 5000 episodes over ``args.test_episode`` meta-test
    tasks and uses ``args.training_samples_per_class`` queries per class.
    The training loss and validation accuracy are logged to TensorBoard.
    Existing checkpoints in ``./models/`` are loaded before training, and the
    best-validating weights are saved back there.
    """
    writer = SummaryWriter('/home/caffe/achu/logs/pytorch_omniglot_FSL.log')
    try:
        # Step 1: init data folders
        print("init data folders")
        # init character folders for dataset construction
        metatrain_character_folders, metatest_character_folders = tg.omniglot_character_folders(
            data_folder=args.dataset_folder,
            no_of_training_samples=args.training_samples_per_class,
            no_of_validation_samples=args.support_set_samples_per_class)

        # Step 2: init neural networks
        print("init neural networks")
        cnn_output_dims = cnn_final_output_dims(args.image_size)
        rn_dims = rn_dims_before_FCN(cnn_output_dims)
        fcn_size = args.channel_dim * (rn_dims**2)
        feature_encoder = CNNEncoder()
        relation_network = RelationNetwork(fcn_size, args.hidden_unit)
        feature_encoder.apply(weights_init)
        relation_network.apply(weights_init)

        # assumes args.gpu is an integer CUDA index (or None for the current device)
        gpu_index = torch.cuda.current_device() if args.gpu is None else int(args.gpu)
        device = torch.device("cuda", gpu_index)
        # nn.DataParallel does not move the wrapped module: its parameters must
        # already sit on device_ids[0], which is also where outputs are gathered.
        feature_encoder.to(device)
        relation_network.to(device)
        gpu_count = torch.cuda.device_count()
        if gpu_count >= 1:
            print("Let's use", gpu_count, "GPUs!")
            device_ids = [gpu_index] + [i for i in range(gpu_count) if i != gpu_index]
            feature_encoder = nn.DataParallel(feature_encoder, device_ids=device_ids)
            relation_network = nn.DataParallel(relation_network, device_ids=device_ids)

        feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                                 lr=args.learning_rate)
        feature_encoder_scheduler = StepLR(feature_encoder_optim, step_size=100000, gamma=0.5)
        relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                                  lr=args.learning_rate)
        relation_network_scheduler = StepLR(relation_network_optim, step_size=100000, gamma=0.5)

        checkpoint_suffix = (str(args.class_num) + "way_" +
                             str(args.training_samples_per_class) + "shot.pkl")
        feature_encoder_path = "./models/omniglot_feature_encoder_" + checkpoint_suffix
        relation_network_path = "./models/omniglot_relation_network_" + checkpoint_suffix
        if os.path.exists(feature_encoder_path):
            feature_encoder.load_state_dict(torch.load(feature_encoder_path, map_location=device))
            print("load feature encoder success")
        if os.path.exists(relation_network_path):
            relation_network.load_state_dict(
                torch.load(relation_network_path, map_location=device))
            print("load relation network success")

        def relation_scores(support_images, query_images, num_queries):
            """Return (num_queries, class_num) relation scores for a support/query pair of batches.

            Support features are summed over the shots of each class, paired
            with every query's features along the channel axis and scored by
            the relation network.
            """
            support = feature_encoder(support_images.to(device))
            support = support.view(args.class_num, args.training_samples_per_class,
                                   args.feature_dim, cnn_output_dims, cnn_output_dims).sum(1)
            queries = feature_encoder(query_images.to(device))
            support_ext = support.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
            query_ext = torch.transpose(
                queries.unsqueeze(0).repeat(args.class_num, 1, 1, 1, 1), 0, 1)
            pairs = torch.cat((support_ext, query_ext), 2).view(
                -1, args.feature_dim * 2, cnn_output_dims, cnn_output_dims)
            return relation_network(pairs).view(-1, args.class_num)

        mse = nn.MSELoss()  # stateless; build once
        num_train_queries = args.support_set_samples_per_class * args.class_num
        num_val_queries = args.training_samples_per_class * args.class_num

        # Step 3: build graph
        print("Training...")
        last_accuracy = 0.0
        for episode in range(args.episode):
            # init dataset
            # sample_dataloader yields the support images queries are compared with;
            # batch_dataloader yields the query images trained on.
            degrees = random.choice([0, 90, 180, 270])
            task = tg.OmniglotTask(metatrain_character_folders, args.class_num,
                                   args.training_samples_per_class,
                                   args.support_set_samples_per_class)
            sample_dataloader = tg.get_data_loader(
                task, num_per_class=args.training_samples_per_class, split="train",
                shuffle=False, rotation=degrees)
            batch_dataloader = tg.get_data_loader(
                task, num_per_class=args.support_set_samples_per_class, split="test",
                shuffle=True, rotation=degrees)

            # sample data (DataLoader iterators have no .next() in current PyTorch)
            samples, _ = next(iter(sample_dataloader))
            batches, batch_labels = next(iter(batch_dataloader))

            relations = relation_scores(samples, batches, num_train_queries)
            one_hot_labels = torch.zeros(num_train_queries, args.class_num).scatter_(
                1, batch_labels.view(-1, 1), 1).to(relations.device)
            loss = mse(relations, one_hot_labels)

            # training
            feature_encoder.zero_grad()
            relation_network.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
            torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
            feature_encoder_optim.step()
            relation_network_optim.step()
            # Step the LR schedules after the optimizers; same schedule as the
            # deprecated step(episode) call at the start of each episode.
            feature_encoder_scheduler.step()
            relation_network_scheduler.step()

            writer.add_scalar('Training loss', loss.item(), episode + 1)

            if (episode + 1) % 100 == 0:
                print("episode:", episode + 1, "loss", loss.item())

            if (episode + 1) % 5000 == 0:
                # test
                print("Testing...")
                total_rewards = 0
                with torch.no_grad():  # validation only: no autograd graph needed
                    for _ in range(args.test_episode):
                        degrees = random.choice([0, 90, 180, 270])
                        task = tg.OmniglotTask(metatest_character_folders, args.class_num,
                                               args.training_samples_per_class,
                                               args.training_samples_per_class)
                        sample_dataloader = tg.get_data_loader(
                            task, num_per_class=args.training_samples_per_class,
                            split="train", shuffle=False, rotation=degrees)
                        # The task holds training_samples_per_class query images per
                        # class, and the scoring below expects exactly that many.
                        test_dataloader = tg.get_data_loader(
                            task, num_per_class=args.training_samples_per_class,
                            split="test", shuffle=True, rotation=degrees)
                        sample_images, _ = next(iter(sample_dataloader))
                        test_images, test_labels = next(iter(test_dataloader))

                        relations = relation_scores(sample_images, test_images, num_val_queries)
                        _, predict_labels = torch.max(relations, 1)
                        total_rewards += (predict_labels ==
                                          test_labels.to(predict_labels.device)).sum().item()

                test_accuracy = total_rewards / num_val_queries / args.test_episode
                print("validation accuracy:", test_accuracy)
                writer.add_scalar('Validation accuracy', test_accuracy, episode + 1)

                if test_accuracy > last_accuracy:
                    # save networks
                    os.makedirs("./models", exist_ok=True)
                    torch.save(feature_encoder.state_dict(), feature_encoder_path)
                    torch.save(relation_network.state_dict(), relation_network_path)
                    print("save networks for episode:", episode)
                    last_accuracy = test_accuracy
    finally:
        writer.close()  # flush pending TensorBoard events even on error
def main():
    """Evaluate saved Omniglot relation-network checkpoints on the meta-test split.

    Builds ``ot.CNNEncoder`` / ``ot.RelationNetwork`` on the global ``device``
    in eval mode and loads their weights from ``./models/`` when present.
    It then runs EPISODE rounds of TEST_EPISODE random CLASS_NUM-way,
    SAMPLE_NUM_PER_CLASS-shot tasks.  Each round prints its mean accuracy
    and confidence half-width ``h``, and the run ends with the average over
    rounds.
    """
    # * Step 1: init data folders
    print("init data folders")
    # * init character folders for dataset construction
    _, metatest_character_folders = tg.omniglot_character_folders()

    # * Step 2: init neural networks
    print("init neural networks")
    feature_encoder = ot.CNNEncoder().to(device)
    relation_network = ot.RelationNetwork(FEATURE_DIM, RELATION_DIM).to(device)
    feature_encoder.eval()
    relation_network.eval()

    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = "./models/omniglot_feature_encoder_" + checkpoint_suffix
    relation_network_path = "./models/omniglot_relation_network_" + checkpoint_suffix
    # map_location lets GPU-saved checkpoints load on whatever `device` is.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path, map_location=device))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path, map_location=device))
        print("load relation network success")

    num_queries = CLASS_NUM * SAMPLE_NUM_PER_CLASS
    total_accuracy = 0.0
    for episode in range(EPISODE):
        # * test
        print("Testing...")
        accuracies = []
        with torch.no_grad():  # inference only: skip autograd bookkeeping
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                       split="train", shuffle=False,
                                                       rotation=degrees)
                test_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                     split="test", shuffle=True,
                                                     rotation=degrees)
                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))
                sample_images = sample_images.to(device)
                test_images, test_labels = test_images.to(device), test_labels.to(device)

                # * Calculate features; shots of each class are summed into one prototype
                sample_features = feature_encoder(sample_images)
                sample_features = sample_features.view(CLASS_NUM, SAMPLE_NUM_PER_CLASS,
                                                       FEATURE_DIM, 5, 5)
                sample_features = torch.sum(sample_features, 1)
                test_features = feature_encoder(test_images)

                # * Pair every query with every class prototype and score the pairs
                sample_features_ext = sample_features.unsqueeze(0).repeat(num_queries, 1, 1, 1, 1)
                test_features_ext = torch.transpose(
                    test_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
                relation_pairs = torch.cat((sample_features_ext, test_features_ext),
                                           2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
                _, predict_labels = torch.max(relations, 1)

                correct = (predict_labels == test_labels).sum().item()
                accuracies.append(correct / num_queries)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print(f'test accuracy : {test_accuracy}, h : {h}')
        total_accuracy += test_accuracy

    print(f"average accuracy : {total_accuracy / EPISODE}")
def main():
    """Evaluate the HOP (second-order pooled) Omniglot similarity network.

    Builds ``models.FeatureEncoder`` / ``models.SimilarityNetwork`` on GPU and
    loads their weights from the ``METHOD`` directory when present.  The
    directory is created if missing.  Then, for EPISODE rounds, it measures
    accuracy on TEST_EPISODE random CLASS_NUM-way tasks with
    SUPPORT_NUM_PER_CLASS support and TEST_NUM_PER_CLASS query images per
    class.  Each round prints that round's accuracy and the best seen so far.
    Each image's feature map is turned into a trace-normalised Gram matrix
    passed through ``power_norm`` before support/query matrices are paired
    and scored.
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    _, metaquery_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = models.FeatureEncoder().apply(weights_init).cuda(GPU)
    relation_network = models.SimilarityNetwork(FEATURE_DIM, RELATION_DIM).apply(
        weights_init).cuda(GPU)

    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SUPPORT_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = str(METHOD + "/omniglot_feature_encoder_" + checkpoint_suffix)
    similarity_network_path = str(METHOD + "/omniglot_similarity_network_" + checkpoint_suffix)
    # Load to CPU first; load_state_dict copies into the GPU parameters.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path, map_location="cpu"))
        print("load feature encoder success")
    if os.path.exists(similarity_network_path):
        relation_network.load_state_dict(torch.load(similarity_network_path, map_location="cpu"))
        print("load similarity network success")

    # No shell involved, so METHOD cannot inject commands.
    os.makedirs(METHOD, exist_ok=True)

    def hop_features(features):
        """Map (B, FEATURE_DIM, P) features to (B, 1, FEATURE_DIM, FEATURE_DIM) HOP matrices.

        Each sample's Gram matrix (scaled by 1/P) is normalised by its trace and
        passed through ``power_norm`` with exponent-like parameter SIGMA.
        """
        out = torch.empty(features.size(0), 1, FEATURE_DIM, FEATURE_DIM, device=features.device)
        for d in range(features.size(0)):
            s = features[d]
            s = (1.0 / features.size(2)) * s.mm(s.t())
            out[d, :, :, :] = power_norm(s / s.trace(), SIGMA)
        return out

    num_queries = TEST_NUM_PER_CLASS * CLASS_NUM

    # Step 3: build graph
    print("Training...")
    best_accuracy = 0.0
    for episode in range(EPISODE):
        with torch.no_grad():
            # query
            print("Testing...")
            total_rewards = 0
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metaquery_character_folders, CLASS_NUM,
                                       SUPPORT_NUM_PER_CLASS, TEST_NUM_PER_CLASS)
                support_dataloader = tg.get_data_loader(task, num_per_class=SUPPORT_NUM_PER_CLASS,
                                                        split="train", shuffle=False,
                                                        rotation=degrees)
                query_dataloader = tg.get_data_loader(task, num_per_class=TEST_NUM_PER_CLASS,
                                                      split="query", shuffle=True,
                                                      rotation=degrees)
                support_images, _ = next(iter(support_dataloader))
                query_images, query_labels = next(iter(query_dataloader))

                # calculate features: support shots are summed per class;
                # assumes encoder maps are FEATURE_DIM x 5 x 5 (25 positions) — TODO confirm
                support_features = feature_encoder(support_images.cuda(GPU)).view(
                    CLASS_NUM, SUPPORT_NUM_PER_CLASS, FEATURE_DIM, 25).sum(1)
                query_features = feature_encoder(query_images.cuda(GPU)).view(
                    num_queries, FEATURE_DIM, 25)

                # HOP features
                H_support_features = hop_features(support_features)
                H_query_features = hop_features(query_features)

                # calculate relations: pair every query matrix with every class matrix
                support_features_ext = H_support_features.unsqueeze(0).repeat(
                    num_queries, 1, 1, 1, 1)
                query_features_ext = torch.transpose(
                    H_query_features.unsqueeze(0).repeat(CLASS_NUM, 1, 1, 1, 1), 0, 1)
                relation_pairs = torch.cat((support_features_ext, query_features_ext),
                                           2).view(-1, 2, FEATURE_DIM, FEATURE_DIM)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
                _, predict_labels = torch.max(relations, 1)

                total_rewards += (predict_labels ==
                                  query_labels.to(predict_labels.device)).sum().item()

            test_accuracy = total_rewards / num_queries / TEST_EPISODE
            # Update before printing so "best accuracy" includes this round.
            if test_accuracy > best_accuracy:
                best_accuracy = test_accuracy
            print("query accuracy:", test_accuracy)
            print("best accuracy:", best_accuracy)
def main():
    """Evaluate a saved Omniglot relation network (|support - query| pairing) on CPU.

    Loads weights from ``./models/`` when present and runs EPISODE rounds of
    TEST_EPISODE random CLASS_NUM-way tasks.  Each round prints its mean
    accuracy and confidence half-width ``h``, and the run ends with the
    average over rounds.

    NOTE(review): the pairing and reward logic below match predictions to
    labels one-to-one only when SAMPLE_NUM_PER_CLASS == 1 (no per-class
    pooling, and only the first CLASS_NUM predictions are scored).
    """
    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction
    _, metatest_character_folders = tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(RELATION1_DIM, RELATION2_DIM, RELATION3_DIM)

    checkpoint_suffix = str(CLASS_NUM) + "way_" + str(SAMPLE_NUM_PER_CLASS) + "shot.pkl"
    feature_encoder_path = "./models/omniglot_feature_encoder_" + checkpoint_suffix
    relation_network_path = "./models/omniglot_relation_network_" + checkpoint_suffix
    # The networks stay on CPU here, so force CPU loading of GPU-saved checkpoints.
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path, map_location="cpu"))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path, map_location="cpu"))
        print("load relation network success")

    num_pairs = SAMPLE_NUM_PER_CLASS * CLASS_NUM
    total_accuracy = 0.0
    for episode in range(EPISODE):
        # test
        print("Testing...")
        accuracies = []
        with torch.no_grad():  # inference only
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS, SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                       split="train", shuffle=False,
                                                       rotation=degrees)
                test_dataloader = tg.get_data_loader(task, num_per_class=SAMPLE_NUM_PER_CLASS,
                                                     split="test", shuffle=True,
                                                     rotation=degrees)
                sample_images, _ = next(iter(sample_dataloader))
                # Note: test_images holds CLASS_NUM * SAMPLE_NUM_PER_CLASS images here
                # (5 for a 5-way 1-shot task).
                test_images, test_labels = next(iter(test_dataloader))

                # calculate features (the repeat(..., 1, 1) below implies 2-D (N, D) features)
                sample_features = feature_encoder(sample_images)
                test_features = feature_encoder(test_images)

                # calculate relations from element-wise absolute feature differences
                sample_features_ext = sample_features.unsqueeze(0).repeat(num_pairs, 1, 1)
                test_features_ext = torch.transpose(
                    test_features.unsqueeze(0).repeat(num_pairs, 1, 1), 0, 1)
                # assumes 1600 == flattened feature size per image — TODO confirm
                relation_pairs = torch.abs(sample_features_ext - test_features_ext).view(-1, 1600)
                relations = relation_network(relation_pairs).view(-1, CLASS_NUM)
                _, predict_labels = torch.max(relations, 1)

                test_labels = test_labels.long()
                correct = (predict_labels[:CLASS_NUM] == test_labels[:CLASS_NUM]).sum().item()
                accuracies.append(correct / 1.0 / CLASS_NUM / SAMPLE_NUM_PER_CLASS)

        test_accuracy, h = mean_confidence_interval(accuracies)
        print("test accuracy:", test_accuracy, "h:", h)
        total_accuracy += test_accuracy

    print("aver_accuracy:", total_accuracy / EPISODE)
def main():
    """Meta-train the CPU relation network on Omniglot (|support - query| pairs).

    Builds the feature encoder (``Inception(1, 10)`` when
    ``USE_INCEPTION_EMBEDDING`` is truthy, otherwise ``CNNEncoder``) and a
    ``RelationNetwork``, initialises both with ``weights_init``, and resumes
    from the
    ``./models/omniglot_<network>_<CLASS_NUM>way_<SAMPLE_NUM_PER_CLASS>shot.pkl``
    checkpoints when present.  It then trains for ``EPISODE`` episodes with
    an MSE loss against one-hot labels.  The loss is printed every 10
    episodes.  Every 100 episodes the model is evaluated on ``TEST_EPISODE``
    meta-test tasks, and both networks are saved whenever the accuracy beats
    the best seen so far.

    Depends on module-level names defined elsewhere in the file (``tg``,
    ``Inception``, ``CNNEncoder``, ``RelationNetwork``, ``weights_init`` and
    the upper-case configuration constants).
    """

    def model_path(name):
        # ./models/omniglot_<name>_<CLASS_NUM>way_<SAMPLE_NUM_PER_CLASS>shot.pkl
        return ("./models/omniglot_" + name + "_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")

    def evaluate():
        """Return the fraction of correctly classified queries over
        TEST_EPISODE rotated meta-test tasks."""
        total_rewards = 0
        total_queries = 0
        repeat_count = SAMPLE_NUM_PER_CLASS * CLASS_NUM
        # Inference only.  The networks are not switched to eval(), so any
        # BatchNorm/Dropout layers keep their training-mode behaviour.
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(
                    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
                    shuffle=False, rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="test",
                    shuffle=True, rotation=degrees)
                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                sample_features = feature_encoder(sample_images)
                test_features = feature_encoder(test_images)
                # Pair every query with every support sample.
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    repeat_count, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    repeat_count, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)
                relation_pairs = torch.abs(
                    sample_features_ext - test_features_ext).view(-1, 1600)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)
                _, predict_labels = torch.max(relations, 1)

                # Score every query; the denominator counts the same queries.
                targets = test_labels.long().tolist()
                total_rewards += sum(
                    int(p == t)
                    for p, t in zip(predict_labels.tolist(), targets))
                total_queries += len(targets)
        return total_rewards / float(total_queries)

    # Step 1: init data folders
    print("init data folders")
    metatrain_character_folders, metatest_character_folders = \
        tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    if USE_INCEPTION_EMBEDDING:
        feature_encoder = Inception(1, 10)
    else:
        feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(RELATION1_DIM, RELATION2_DIM,
                                       RELATION3_DIM)

    # Initialise the weights via Module.apply().  This variant trains on the
    # CPU (no .cuda() calls).
    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    # One Adam optimiser per network.  StepLR multiplies the learning rate by
    # gamma=0.5 every step_size=100000 scheduler steps; the scheduler is
    # stepped once per episode below.
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000, gamma=0.5)

    # Resume from existing checkpoints.  Remap to the CPU because these
    # networks stay on the CPU.
    feature_encoder_path = model_path("feature_encoder")
    relation_network_path = model_path("relation_network")
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(
            torch.load(feature_encoder_path, map_location="cpu"))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(
            torch.load(relation_network_path, map_location="cpu"))
        print("load relation network success")

    # The loss module is stateless, so build it once.
    mse = nn.MSELoss()

    # Step 3: build graph
    print("Training...")
    last_accuracy = 0.0
    for episode in range(EPISODE):
        # Build the support set ("train" split, SAMPLE_NUM_PER_CLASS per
        # class) and the query set ("test" split, BATCH_NUM_PER_CLASS per
        # class), both with the same random rotation.
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
            shuffle=False, rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task, num_per_class=BATCH_NUM_PER_CLASS, split="test",
            shuffle=True, rotation=degrees)

        # next(iter(...)): Python 3 iterators have no .next() method.
        # Per the author's notes, for 5-way 1-shot with 19 queries per class
        # samples is [5, 1, 28, 28] and batches is [95, 1, 28, 28].
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # Extract features.  The 3-argument repeat() below implies the
        # encoder returns flattened 2-D features, presumably (N, 1600)
        # -- confirm.
        sample_features = feature_encoder(samples)
        batch_features = feature_encoder(batches)

        # Pair every query with every support sample:
        # (num_queries, num_support, D) after the transpose.
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
        # Metric learning: relation input is |support - query| (instead of
        # concatenating the two feature maps).
        relation_pairs = torch.abs(
            sample_features_ext - batch_features_ext).view(-1, 1600)
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        # Optimisation target: one-hot encoding of the query labels.
        one_hot_labels = torch.zeros(
            BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(
                1, batch_labels.view(-1, 1).long(), 1)
        loss = mse(relations, one_hot_labels)

        # training: zero the gradients, backpropagate, clip, then update.
        feature_encoder.zero_grad()
        relation_network.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        feature_encoder_optim.step()
        relation_network_optim.step()
        # Stepping once after each optimizer step gives episode e the learning
        # rate LEARNING_RATE * 0.5 ** (e // 100000), the same value the
        # deprecated scheduler.step(episode) call produced.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 10 == 0:
            # loss is 0-dim; indexing it with [0] fails on modern PyTorch.
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 100 == 0:
            # test
            print("Testing...")
            test_accuracy = evaluate()
            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks (create the target directory if needed)
                os.makedirs("./models", exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy
def main():
    """Meta-train the GPU relation network on Omniglot (concatenated pairs).

    Builds ``CNNEncoder`` and ``RelationNetwork(FEATURE_DIM, RELATION_DIM)``,
    initialises them with ``weights_init``, moves them to device ``GPU``, and
    resumes from the
    ``./models/omniglot_<network>_<CLASS_NUM>way_<SAMPLE_NUM_PER_CLASS>shot.pkl``
    checkpoints when present.  It then trains for ``EPISODE`` episodes with
    an MSE loss against one-hot labels.  The loss is printed every 100
    episodes.  Every 5000 episodes the model is evaluated on
    ``TEST_EPISODE`` meta-test tasks, and both networks are saved whenever
    the accuracy beats the best seen so far.

    Depends on module-level names defined elsewhere in the file (``tg``,
    ``CNNEncoder``, ``RelationNetwork``, ``weights_init`` and the upper-case
    configuration constants).
    """

    def model_path(name):
        # ./models/omniglot_<name>_<CLASS_NUM>way_<SAMPLE_NUM_PER_CLASS>shot.pkl
        return ("./models/omniglot_" + name + "_" + str(CLASS_NUM) + "way_" +
                str(SAMPLE_NUM_PER_CLASS) + "shot.pkl")

    def evaluate():
        """Return the fraction of correctly classified queries over
        TEST_EPISODE rotated meta-test tasks."""
        total_rewards = 0
        total_queries = 0
        repeat_count = SAMPLE_NUM_PER_CLASS * CLASS_NUM
        # Inference only.  The networks are not switched to eval(), so any
        # BatchNorm/Dropout layers keep their training-mode behaviour.
        with torch.no_grad():
            for _ in range(TEST_EPISODE):
                degrees = random.choice([0, 90, 180, 270])
                task = tg.OmniglotTask(metatest_character_folders, CLASS_NUM,
                                       SAMPLE_NUM_PER_CLASS,
                                       SAMPLE_NUM_PER_CLASS)
                sample_dataloader = tg.get_data_loader(
                    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
                    shuffle=False, rotation=degrees)
                test_dataloader = tg.get_data_loader(
                    task, num_per_class=SAMPLE_NUM_PER_CLASS, split="test",
                    shuffle=True, rotation=degrees)
                # Per the author's notes, both are [5, 1, 28, 28] for
                # 5-way 1-shot (5 images used as queries).
                sample_images, _ = next(iter(sample_dataloader))
                test_images, test_labels = next(iter(test_dataloader))

                sample_features = feature_encoder(sample_images.cuda(GPU))
                test_features = feature_encoder(test_images.cuda(GPU))
                sample_features_ext = sample_features.unsqueeze(0).repeat(
                    repeat_count, 1, 1, 1, 1)
                test_features_ext = test_features.unsqueeze(0).repeat(
                    repeat_count, 1, 1, 1, 1)
                test_features_ext = torch.transpose(test_features_ext, 0, 1)
                relation_pairs = torch.cat(
                    (sample_features_ext, test_features_ext),
                    2).view(-1, FEATURE_DIM * 2, 5, 5)
                relations = relation_network(relation_pairs).view(
                    -1, CLASS_NUM)
                # Row-wise argmax = predicted class per query.
                _, predict_labels = torch.max(relations, 1)

                # .tolist() brings the GPU predictions to plain ints so they
                # compare cleanly with the CPU labels.  Every query is scored.
                targets = test_labels.long().tolist()
                total_rewards += sum(
                    int(p == t)
                    for p, t in zip(predict_labels.tolist(), targets))
                total_queries += len(targets)
        return total_rewards / float(total_queries)

    # Step 1: init data folders
    print("init data folders")
    # init character folders for dataset construction: training and test
    # folder lists, each folder holding one character class.
    metatrain_character_folders, metatest_character_folders = \
        tg.omniglot_character_folders()

    # Step 2: init neural networks
    print("init neural networks")
    feature_encoder = CNNEncoder()
    relation_network = RelationNetwork(FEATURE_DIM, RELATION_DIM)

    feature_encoder.apply(weights_init)
    relation_network.apply(weights_init)

    # Move to the GPU before building the optimisers.
    feature_encoder.cuda(GPU)
    relation_network.cuda(GPU)

    # StepLR multiplies the learning rate by gamma=0.5 every step_size=100000
    # scheduler steps; the scheduler is stepped once per episode below.
    feature_encoder_optim = torch.optim.Adam(feature_encoder.parameters(),
                                             lr=LEARNING_RATE)
    feature_encoder_scheduler = StepLR(feature_encoder_optim,
                                       step_size=100000, gamma=0.5)
    relation_network_optim = torch.optim.Adam(relation_network.parameters(),
                                              lr=LEARNING_RATE)
    relation_network_scheduler = StepLR(relation_network_optim,
                                        step_size=100000, gamma=0.5)

    feature_encoder_path = model_path("feature_encoder")
    relation_network_path = model_path("relation_network")
    if os.path.exists(feature_encoder_path):
        feature_encoder.load_state_dict(torch.load(feature_encoder_path))
        print("load feature encoder success")
    if os.path.exists(relation_network_path):
        relation_network.load_state_dict(torch.load(relation_network_path))
        print("load relation network success")

    # The loss module is stateless, so build it once.
    mse = nn.MSELoss().cuda(GPU)

    # Step 3: build graph
    print("Training...")
    last_accuracy = 0.0
    for episode in range(EPISODE):
        # One new task per episode: CLASS_NUM classes from the training
        # folders, SAMPLE_NUM_PER_CLASS support and BATCH_NUM_PER_CLASS query
        # images per class, all with the same random rotation.
        degrees = random.choice([0, 90, 180, 270])
        task = tg.OmniglotTask(metatrain_character_folders, CLASS_NUM,
                               SAMPLE_NUM_PER_CLASS, BATCH_NUM_PER_CLASS)
        sample_dataloader = tg.get_data_loader(
            task, num_per_class=SAMPLE_NUM_PER_CLASS, split="train",
            shuffle=False, rotation=degrees)
        batch_dataloader = tg.get_data_loader(
            task, num_per_class=BATCH_NUM_PER_CLASS, split="test",
            shuffle=True, rotation=degrees)

        # next(iter(...)): Python 3 iterators have no .next() method.
        # Per the author's notes: [5, 1, 28, 28] and [95, 1, 28, 28].
        samples, _ = next(iter(sample_dataloader))
        batches, batch_labels = next(iter(batch_dataloader))

        # calculate features, e.g. [5, 64, 5, 5] and [95, 64, 5, 5]
        sample_features = feature_encoder(samples.cuda(GPU))
        batch_features = feature_encoder(batches.cuda(GPU))

        # Pair every query with every support sample:
        # e.g. [95, 5, 64, 5, 5] for both after the transpose.
        sample_features_ext = sample_features.unsqueeze(0).repeat(
            BATCH_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = batch_features.unsqueeze(0).repeat(
            SAMPLE_NUM_PER_CLASS * CLASS_NUM, 1, 1, 1, 1)
        batch_features_ext = torch.transpose(batch_features_ext, 0, 1)
        # Concatenate along the channel (depth) dimension, then flatten the
        # pairs: e.g. [95, 5, 128, 5, 5] -> [475, 128, 5, 5].
        relation_pairs = torch.cat(
            (sample_features_ext, batch_features_ext),
            2).view(-1, FEATURE_DIM * 2, 5, 5)
        # One row of CLASS_NUM relation scores per query, e.g. [95, 5].
        relations = relation_network(relation_pairs).view(-1, CLASS_NUM)

        # MSE between the relation scores and one-hot query labels.
        one_hot_labels = torch.zeros(
            BATCH_NUM_PER_CLASS * CLASS_NUM, CLASS_NUM).scatter_(
                1, batch_labels.view(-1, 1).long(), 1).cuda(GPU)
        loss = mse(relations, one_hot_labels)

        # training
        feature_encoder.zero_grad()
        relation_network.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(feature_encoder.parameters(), 0.5)
        torch.nn.utils.clip_grad_norm_(relation_network.parameters(), 0.5)
        feature_encoder_optim.step()
        relation_network_optim.step()
        # Stepping once after each optimizer step gives episode e the learning
        # rate LEARNING_RATE * 0.5 ** (e // 100000), the same value the
        # deprecated scheduler.step(episode) call produced.
        feature_encoder_scheduler.step()
        relation_network_scheduler.step()

        if (episode + 1) % 100 == 0:
            # .item() prints the scalar value rather than a tensor repr.
            print("episode:", episode + 1, "loss", loss.item())

        if (episode + 1) % 5000 == 0:
            # test
            print("Testing...")
            test_accuracy = evaluate()
            print("test accuracy:", test_accuracy)

            if test_accuracy > last_accuracy:
                # save networks (create the target directory if needed)
                os.makedirs("./models", exist_ok=True)
                torch.save(feature_encoder.state_dict(), feature_encoder_path)
                torch.save(relation_network.state_dict(),
                           relation_network_path)
                print("save networks for episode:", episode)
                last_accuracy = test_accuracy