# (Assumed imports for this excerpt; valset, args, CategoriesSampler, and the
# model classes are defined elsewhere in the surrounding file.)
import torch
from torch.utils.data import DataLoader

# Episodic validation loader: 500 episodes, each sampling `validation_way`
# classes with `shot` support and `query` query examples per class.
val_sampler = CategoriesSampler(valset.label, 500, args.validation_way,
                                args.shot + args.query)
val_loader = DataLoader(dataset=valset, batch_sampler=val_sampler,
                        num_workers=8, pin_memory=True)

if args.model_type.lower() == 'protonet':
    model = ProtoNet(args)
elif args.model_type.lower() == 'hypnet':
    model = HypNet(args)
elif args.model_type.lower() == 'protonetwithhyperbolic':
    model = ProtoNetWithHyperbolic(args)
else:
    raise ValueError(f"unknown model type: {args.model_type}")

optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)
if args.lr_decay:
    lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer, step_size=args.step_size, gamma=args.gamma)

# Load a pre-trained encoder, skipping the FC (classifier) weights.
model_dict = model.state_dict()
if args.init_weights is not None:
    pretrained_dict = torch.load(args.init_weights)['params']
    # Prefix the saved keys so they match the encoder submodule.
    pretrained_dict = {'encoder.' + k: v for k, v in pretrained_dict.items()}
    # Keep only the keys that exist in the current model; this drops the FC head.
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    model_dict.update(pretrained_dict)
    model.load_state_dict(model_dict)
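# --- Illustrative sketch (not part of the original file) ---------------------
# Each batch drawn from val_loader above is one few-shot episode holding
# validation_way * (shot + query) samples. Assuming the sampler orders each
# episode support-first, as is common in prototypical-network codebases, the
# hypothetical helper below shows how such a batch would be split:
def split_episode(data, way, shot):
    """Split one episodic batch into support and query sets."""
    p = way * shot
    return data[:p], data[p:]

# e.g.: support, query = split_episode(batch, args.validation_way, args.shot)
# ------------------------------------------------------------------------------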
# (Assumed imports and globals for this snippet; get_dataloader, ProtoNet,
# ResNet, PrototypicalLoss, train, validate, save_checkpoint, args, writer,
# best_acc1, and device come from the surrounding file.)
import numpy as np
import torch
import torch.backends.cudnn as cudnn
from glob import glob


def main():
    global args, best_acc1, device

    # Seed all RNGs for reproducibility.
    np.random.seed(args.manual_seed)
    torch.manual_seed(args.manual_seed)
    torch.cuda.manual_seed(args.manual_seed)

    # Omniglot is grayscale (1 channel); the other datasets are RGB (3 channels).
    if args.dataset == 'omniglot':
        train_loader, val_loader = get_dataloader(args, 'trainval', 'test')
        input_dim = 1
    else:
        train_loader, val_loader = get_dataloader(args, 'train', 'val')
        input_dim = 3

    if args.model == 'protonet':
        model = ProtoNet(input_dim).to(device)
        print("ProtoNet loaded")
    else:
        model = ResNet(input_dim).to(device)
        print("ResNet loaded")

    criterion = PrototypicalLoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), args.lr)
    cudnn.benchmark = True

    if args.resume:
        # Prefer the latest periodic checkpoint; fall back to the best model.
        try:
            checkpoint = torch.load(
                sorted(glob(f'{args.log_dir}/checkpoint_*.pth'), key=len)[-1])
        except Exception:
            checkpoint = torch.load(args.log_dir + '/model_best.pth')
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # Resume from the epoch after the one stored in the checkpoint.
        start_epoch = checkpoint['epoch'] + 1
        best_acc1 = checkpoint['best_acc1']
        print(f"loaded checkpoint {args.exp_name}")
    else:
        start_epoch = 1

    scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer=optimizer, gamma=args.lr_scheduler_gamma,
        step_size=args.lr_scheduler_step)

    print(f"trainable parameters: "
          f"{sum(p.numel() for p in model.parameters() if p.requires_grad)}")

    for epoch in range(start_epoch, args.epochs + 1):
        train_loss = train(train_loader, model, optimizer, criterion, epoch)

        # Validate every `test_iter` epochs, plus the first and last epochs.
        is_test = epoch % args.test_iter == 0
        if is_test or epoch == args.epochs or epoch == 1:
            val_loss, acc1 = validate(val_loader, model, criterion, epoch)
            is_best = acc1 >= best_acc1
            if is_best:
                best_acc1 = acc1
            save_checkpoint(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'best_acc1': best_acc1,
                    'optimizer_state_dict': optimizer.state_dict(),
                }, is_best, args)
            if is_best:
                writer.add_scalar("BestAcc", acc1, epoch)
            print(f"[{epoch}/{args.epochs}] train {train_loss:.3f}, "
                  f"val {val_loss:.3f}, acc {acc1:.3f} (best {best_acc1:.3f})")
        else:
            print(f"[{epoch}/{args.epochs}] train {train_loss:.3f}")

        scheduler.step()

    writer.close()
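# --- Illustrative sketch (not part of the original file) ---------------------
# PrototypicalLoss above is assumed to follow Snell et al. (2017): each class
# prototype is the mean of its support embeddings, and queries are classified
# by a softmax over negative squared Euclidean distances to the prototypes.
# A minimal standalone version, assuming embeddings are ordered class-by-class:
import torch
import torch.nn.functional as F


def prototypical_loss(support, query, n_way, n_support, n_query):
    """support: (n_way * n_support, d); query: (n_way * n_query, d)."""
    # One prototype per class: the mean of that class's support embeddings.
    prototypes = support.view(n_way, n_support, -1).mean(dim=1)  # (n_way, d)
    # Squared Euclidean distance from every query to every prototype.
    dists = torch.cdist(query, prototypes) ** 2  # (n_way * n_query, n_way)
    # Queries are grouped class-by-class, so labels are 0,...,0, 1,...,1, ...
    labels = torch.arange(n_way).repeat_interleave(n_query)
    # Cross-entropy over negative distances: closer prototype -> higher logit.
    return F.cross_entropy(-dists, labels)
# ------------------------------------------------------------------------------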