def main():
    """Entry point: parse CLI args, build the Learner-wrapped ConvNet,
    optionally resume from a checkpoint, then dispatch to training or
    evaluation depending on ``args.mode``.

    Side effects: sets the module-level ``args`` and the CUDA environment
    variables before any CUDA context is created.
    """
    global args
    args = parser.parse_args()
    # Make CUDA device numbering match nvidia-smi; see issue #152 on stackoverflow.
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.GPU_ID)
    exp_name = args.name
    kwargs = {'num_workers': 4}  # DataLoader worker settings passed through

    # Create model, use Learner to wrap it.
    model = Learner(ConvNet())
    model = model.cuda()
    cudnn.benchmark = True

    # Optionally resume from a checkpoint.
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['iter']
            # NOTE: the original also read checkpoint['prec'] into an unused
            # local; dropped here as dead code.
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (iter {})".format(
                args.resume, checkpoint['iter']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.mode == 'train':
        train(model, exp_name, kwargs)
    else:
        evaluation(model, args.data_dir, args.batch_size, kwargs)
def test_resnet():
    """Smoke-test ``Learner.functional`` on a torchvision resnet18.

    Skips with a warning when torchvision/metann are not installed.

    Fixes two defects in the original: ``Warning('…')`` merely built an
    exception instance without raising or emitting anything, and the
    trailing ``finally: return`` silently swallowed every exception the
    test body raised (defeating the purpose of a test).
    """
    import warnings
    try:
        from torchvision.models.resnet import resnet18
        from metann import Learner
    except ImportError:
        # warnings.warn actually emits the message to the user.
        warnings.warn('torchvision not included, cannot be tested')
        return
    net = resnet18()
    net = Learner(net)
    print(net.functional(net.parameters(), True, torch.randn(3, 3, 224, 224)))
def test_learner():
    """Train a tiny Learner-wrapped convnet with manual functional SGD
    and assert the cross-entropy loss drops below 0.05."""
    net = Learner(
        nn.Sequential(
            nn.Conv2d(3, 3, 3),
            nn.Conv2d(3, 3, 3),
            Flatten(),
            nn.Linear(3, 4),
        )).to(device)
    x = torch.randn(3, 3, 5, 5).to(device)
    y = torch.randint(0, 4, (3, )).to(device)
    criterion = nn.CrossEntropyLoss()
    weights = list(net.parameters())
    for _ in range(500):
        # Forward with explicit parameters, then take one SGD step by
        # rebuilding the parameter list outside the autograd graph.
        loss = criterion(net.functional(weights, True, x), y)
        grads = torch.autograd.grad(loss, weights)
        with torch.no_grad():
            weights = [(w - 0.01 * g).requires_grad_()
                       for w, g in zip(weights, grads)]
    print(loss)
    assert loss <= 0.05
def __init__(self, config):
    """Build the CNN described by *config* and wrap it in a Learner.

    Args:
        config: architecture specification consumed by ``get_cnn``.
    """
    # Zero-argument super() is the equivalent Python 3 idiom for
    # super(CNN, self) and avoids repeating the class name.
    super().__init__()
    self.module = Learner(get_cnn(config))
def run(rank, size, args):
    """ Distributed Synchronous SGD Example

    One worker of a distributed MAML-style meta-training job on
    MiniImagenet: builds a small conv net, wraps it in Meta/Learner,
    synchronizes weights across ranks, then alternates meta-training
    steps with periodic evaluation on the test split.

    NOTE(review): the source was recovered from a whitespace-collapsed
    file; statement tokens are unchanged but indentation of the inner
    evaluation block was reconstructed — verify against the original.

    Args:
        rank: this process's rank (not read directly; logging uses
            dist.get_rank() — presumably the same value, confirm).
        size: world size from the launcher (unused; args.world_size is
            read instead — TODO confirm they agree).
        args: namespace with device, meta hyper-parameters (update_lr,
            meta_lr, update_step, update_step_test), task settings
            (n_way, k_shot, k_query, task_num, world_size) and epoch.
    """
    device = torch.device(args.device)
    # Architecture spec consumed by get_cnn: four conv->relu->bn->pool
    # stages followed by flatten and a 5-way linear head.
    config = [
        ('conv2d', [3, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 2]),
        ('conv2d', [32, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 2]),
        ('conv2d', [32, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 2]),
        ('conv2d', [32, 32, 3]),
        ('relu', [True]),
        ('bn2d', [32]),
        ('max_pool2d', [2, 1]),
        ('flatten', ),
        ('linear', [32 * 5 * 5, 5]),
    ]
    train_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='train')
    # valid_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='validation')
    test_dataset = l2l.vision.datasets.MiniImagenet(root='./data', mode='test')
    # train_loader = task_loader(train_dataset, args.n_way, args.k_shot, args.k_query, 10000,
    #                            batch_size=args.task_num//args.world_size)
    # test_loader = task_loader(test_dataset, args.n_way, args.k_shot, args.k_query, 1024,
    #                           batch_size=args.task_num//args.world_size)
    net = get_cnn(config)  # TODO: to be changed (translated from original note "要改")
    model = Meta(update_lr=args.update_lr, meta_lr=args.meta_lr,
                 update_step=args.update_step,
                 update_step_test=args.update_step_test,
                 learner=Learner(net)).to(device)
    average_model(model)  # synchronize initial weights across ranks
    optimizer = model.meta_optim
    # Count trainable parameters for the startup report.
    tmp = filter(lambda x: x.requires_grad, model.parameters())
    num = sum(map(lambda x: np.prod(x.shape), tmp))
    print(model)
    print('Total trainable tensors:', num)
    # num_batches = ceil(len(train_set) / float(args.batch_size))
    # for epoch in range(args.epoch):
    # Each "epoch" consumes one freshly sampled 10000-task loader.
    for epoch in range(args.epoch // 10000):
        epoch_loss = 0.0  # NOTE(review): never updated below — apparently dead
        average_model(model)  # re-sync weights at the start of each epoch
        train_loader = task_loader(train_dataset, args.n_way, args.k_shot,
                                   args.k_query, 10000,
                                   batch_size=args.task_num // args.world_size)
        for step, data in enumerate(train_loader):
            # data = tuple(map(lambda x: slc(to_device(relabel(x), device)), data))
            # Move every task's support/query tensors to the device.
            data = [[x.to(device) for x in collate(a) + collate(b)]
                    for a, b in data]
            optimizer.zero_grad()
            if step * args.task_num % 120 == 0:
                # Periodic forward pass under the logging context so
                # training accuracy can be read back via model.accs().
                with model.logging:
                    loss = model(data)
                accs = model.accs()
                print('\rRank ', dist.get_rank(), 'step:', step,
                      '\ttraining acc:', accs)
            else:
                loss = model(data)
            loss.backward()
            average_gradients(model)  # all-reduce: synchronous SGD step
            optimizer.step()
            # if epoch % 5 == 0:
            # evaluation
            if step * args.task_num % 2000 == 0:
                accs_all_test = []
                test_loader = task_loader(test_dataset, args.n_way,
                                          args.k_shot, args.k_query, 1024,
                                          batch_size=args.task_num // args.world_size)
                model.eval()
                for data_test in test_loader:
                    data_test = [[
                        x.to(device) for x in collate(a) + collate(b)
                    ] for a, b in data_test]
                    with model.logging:
                        # data_test = tuple(map(lambda x: slc(to_device(relabel(x), device)), data_test))
                        loss = model(data_test)
                        # Backward is needed for the inner-loop adaptation;
                        # gradients are discarded via zero_grad below.
                        loss.backward()
                        # accs = model.accs()
                        accs_all_test.append(model.log['corrects'])
                        optimizer.zero_grad()
                # [b, update_step+1]
                accs = np.array(accs_all_test).mean(axis=0).astype(np.float16)
                print('Rank ', dist.get_rank(), ', epoch ', epoch, ': ',
                      'Test acc:', accs)
                optimizer.zero_grad()
                del data_test
                model.train()