# Training-side setup: configure cuDNN, resolve the log directory, make sure the
# checkpoint folders exist, build the embedding/GNN modules and the
# train/validation data loaders.

# Turn off cuDNN benchmark mode (presumably for run-to-run consistency).
torch.backends.cudnn.benchmark = False

# A user-supplied log directory wins; otherwise fall back to the current one.
if tt.arg.log_dir_user is None:
    tt.arg.log_dir_user = tt.arg.log_dir
tt.arg.log_dir = tt.arg.log_dir_user

# exist_ok=True avoids the check-then-create race of
# `if not os.path.exists(...): os.makedirs(...)`.
os.makedirs('asset/checkpoints', exist_ok=True)
os.makedirs('asset/checkpoints/' + tt.arg.experiment, exist_ok=True)

enc_module = EmbeddingImagenet(emb_size=tt.arg.emb_size)

# NOTE(review): node_features is fed from num_edge_features and edge_features
# from num_node_features. This looks swapped but is kept as-is because the
# GraphNetwork definition is not visible here -- confirm before changing.
gnn_module = GraphNetwork(
    in_features=tt.arg.emb_size,
    node_features=tt.arg.num_edge_features,
    edge_features=tt.arg.num_node_features,
    num_layers=tt.arg.num_layers,
    dropout=tt.arg.dropout,
)

# Pick the dataset loader class by name and build the train/val partitions.
if tt.arg.dataset == 'mini':
    train_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='train')
    valid_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='val')
elif tt.arg.dataset == 'tiered':
    train_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='train')
    valid_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='val')
else:
    # Fail here with a clear message instead of printing and then hitting a
    # NameError on the undefined loaders below.
    print('Unknown dataset!')
    raise ValueError('Unknown dataset: {!r}'.format(tt.arg.dataset))

data_loader = {
    'train': train_loader,
    'val': valid_loader,
}
# Evaluation-side setup: finish building the experiment name from the
# arguments, verify it matches the requested test model, then build the GNN
# module and the test data loader.

# Append the hyper-parameter suffixes to the experiment name.
# NOTE(review): the seed suffix is hard-coded to 222 -- confirm this matches
# how the checkpoints being tested were named.
exp_name += '_N-{}_K-{}_U-{}'.format(tt.arg.num_ways, tt.arg.num_shots, tt.arg.num_unlabeled)
exp_name += '_L-{}_B-{}'.format(tt.arg.num_layers, tt.arg.meta_batch_size)
exp_name += '_C-{}'.format(tt.arg.num_cell)
exp_name += '_T-{}_SEED-222'.format(tt.arg.transductive)

if exp_name != tt.arg.test_model:
    print(exp_name)
    print(tt.arg.test_model)
    print('Test model and input arguments are mismatched!')
    # The original only instantiated AssertionError() without raising it, so
    # a mismatch was silently ignored; actually abort here.
    raise AssertionError('Test model and input arguments are mismatched!')

# NOTE(review): node_features is fed from num_edge_features and edge_features
# from num_node_features. This looks swapped but is kept as-is because the
# GraphNetwork definition is not visible here -- confirm before changing.
gnn_module = GraphNetwork(
    in_features=tt.arg.emb_size,
    node_features=tt.arg.num_edge_features,
    edge_features=tt.arg.num_node_features,
    num_layers=tt.arg.num_layers,
    num_cell=tt.arg.num_cell,
    dropout=tt.arg.dropout,
    arch=tt.arg.arch,
).cuda()

# Pick the dataset loader class by name and build the test partition.
if tt.arg.dataset == 'mini':
    test_loader = MiniImagenetLoader(root=tt.arg.dataset_root, partition='test')
elif tt.arg.dataset == 'tiered':
    test_loader = TieredImagenetLoader(root=tt.arg.dataset_root, partition='test')
else:
    # Fail here with a clear message instead of printing and then hitting a
    # NameError on the undefined test_loader below.
    print('Unknown dataset!')
    raise ValueError('Unknown dataset: {!r}'.format(tt.arg.dataset))

data_loader = {'test': test_loader}