    # (Continuation of the optimizer-selection chain: this branch runs for the
    # matching --optimizer choice; the final `else` rejects unknown values.)
    # Optional LR schedule: decay the learning rate by 0.2 at epochs 60, 120,
    # and 160. `last_epoch` keeps the schedule aligned when resuming training.
    if args.scheduler:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[60, 120, 160], gamma=0.2, last_epoch=args.begin_epoch - 1
        )
else:
    raise ValueError('Unknown optimizer {}'.format(args.optimizer))

best_test_bpd = math.inf
if args.resume is not None:
    logger.info('Resuming model from {}'.format(args.resume))

    # Run one dummy forward pass so that lazily-initialized parameters and
    # buffers exist before the checkpoint's state dict is loaded into them.
    with torch.no_grad():
        x = torch.rand(1, *input_size[1:]).to(device)
        model(x)

    checkpt = torch.load(args.resume)
    # Restore everything from the checkpoint except cached sample buffers.
    sd = {k: v for k, v in checkpt['state_dict'].items() if 'last_n_samples' not in k}
    state = model.state_dict()
    state.update(sd)
    model.load_state_dict(state, strict=True)
    ema.set(checkpt['ema'])
    if 'optimizer_state_dict' in checkpt:
        optimizer.load_state_dict(checkpt['optimizer_state_dict'])
        # torch.load leaves the optimizer state on CPU; move it to the device manually.
        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.to(device)
    del checkpt
    del state

logger.info(optimizer)
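# Note: constructing MultiStepLR above only defines the schedule; it is
# typically advanced once per epoch in the training loop. A minimal sketch,
# assuming a `train(epoch)` helper and an `args.nepochs` flag that are not
# shown in this section:
#
#     for epoch in range(args.begin_epoch, args.nepochs):
#         train(epoch)
#         if args.scheduler:
#             scheduler.step()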