class Agent():
    """Q-learning agent wrapping an online network, a target network and an optimizer.

    When ``categorical`` is set, network outputs are multiplied by a fixed
    ``support`` of ``atoms`` values and summed over the last dimension to rank
    actions; otherwise the outputs are summed over the last dimension directly
    (presumably a trailing dimension of size 1 - confirm against ``DQN``).
    """

    def __init__(self, args, action_space):
        """Create networks, optimizer and mixed-precision / distributed wrappers.

        Reads from ``args``: ``multi_step``, ``discount``, ``target_update``,
        ``categorical``, ``noisy_linear``, ``double_q``, ``max_grad_norm``,
        ``gpu``, ``model``, ``lr``, ``adam_eps``, ``opt_level``, ``loss_scale``,
        ``distributed`` and, when ``categorical`` is set, ``atoms``, ``v_min``
        and ``v_max``. ``action_space.n`` is passed to ``DQN`` and used as the
        exclusive upper bound for random actions.
        """
        self.action_space = action_space
        self.n = args.multi_step
        self.discount = args.discount
        # target_update, noisy_linear and num_param_updates are stored here but
        # not read by any method of this class.
        self.target_update = args.target_update
        self.categorical = args.categorical
        self.noisy_linear = args.noisy_linear
        self.double_q = args.double_q
        self.max_grad_norm = args.max_grad_norm
        self.device = torch.device('cuda', args.gpu)
        self.num_param_updates = 0

        if args.categorical:
            self.atoms = args.atoms
            self.v_min = args.v_min
            self.v_max = args.v_max
            self.support = torch.linspace(
                self.v_min, args.v_max,
                self.atoms).to(device=self.device)  # Support (range) of z
            # Spacing between neighbouring support values.
            self.delta_z = (args.v_max - self.v_min) / (self.atoms - 1)

        self.online_net = DQN(args, self.action_space.n).to(device=self.device)

        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        # The target network starts as a copy of the online network, is kept in
        # eval mode and is excluded from gradient computation.
        self.target_net = DQN(args, self.action_space.n).to(device=self.device)
        self.update_target_net()
        self.target_net.eval()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimizer = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps,
                                    amsgrad=True)

        # Both networks and the optimizer are handed to amp together.
        [self.online_net, self.target_net
         ], self.optimizer = amp.initialize([self.online_net, self.target_net],
                                            self.optimizer,
                                            opt_level=args.opt_level,
                                            loss_scale=args.loss_scale)

        if args.distributed:
            # After wrapping, the DQN instance is reached through ``.module``
            # (see reset_noise and learn).
            self.online_net = DDP(self.online_net, delay_allreduce=True)
            self.target_net = DDP(self.target_net, delay_allreduce=True)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        """Call ``reset_noise`` on the online DQN, unwrapping ``.module`` if it is wrapped."""
        if isinstance(self.online_net, DQN):
            self.online_net.reset_noise()
        else:
            self.online_net.module.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        """Return the greedy action(s) for ``state`` on ``state``'s own device.

        The forward pass runs on ``self.device`` without gradient tracking.
        NOTE(review): ``act_e_greedy`` indexes ``state.size(0)``, so callers
        appear to pass a leading batch dimension despite the comment above -
        confirm.
        """
        with torch.no_grad():
            probs = self.online_net(state.to(self.device))
            if self.categorical:
                # Weight each atom's probability by its support value so the
                # sum below is an expectation over the support.
                probs = self.support.expand_as(probs) * probs
            actions = probs.sum(-1).argmax(-1).to(state.device)
        return actions

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self, state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        """Return greedy actions, each replaced by a uniform random action with probability ``epsilon``.

        One independent draw is made per entry along ``state``'s first dimension.
        """
        actions = self.act(state)

        mask = torch.rand(
            state.size(0), device=state.device, dtype=torch.float32) < epsilon
        masked = mask.sum().item()  # number of entries to randomise
        if masked > 0:
            actions[mask] = torch.randint(0,
                                          self.action_space.n, (masked, ),
                                          device=state.device,
                                          dtype=torch.long)

        return actions

    def learn(self, states, actions, returns, next_states, nonterminals,
              weights):
        """Run one optimisation step on a batch and return the detached per-sample loss.

        ``actions`` index dimension 1 of the network output; ``returns`` and
        ``nonterminals`` enter the target as ``returns + nonterminals *
        discount**n * (next-state term)``; ``weights`` scale the loss before
        it is averaged. Targets are computed under ``torch.no_grad()``.
        Tensor shapes are not established in this class beyond the indexing
        below - confirm against the caller.
        """

        # Shape the action indices so they can be used with gather on dim 1.
        tactions = actions.unsqueeze(-1).unsqueeze(-1)
        if self.categorical:
            tactions = tactions.expand(-1, -1, self.atoms)

        # Calculate current state probabilities (online network noise already sampled)
        nvtx.range_push('agent:online (state) probs')
        ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        ps_a = ps.gather(1, tactions)  # log p(s_t, a_t; θonline)
        nvtx.range_pop()

        with torch.no_grad():
            if isinstance(self.target_net, DQN):
                self.target_net.reset_noise()
            else:
                self.target_net.module.reset_noise(
                )  # Sample new target net noise

            nvtx.range_push('agent:target (next state) probs')
            tns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            nvtx.range_pop()

            if self.double_q:
                # Calculate nth next state probabilities
                nvtx.range_push('agent:online (next state) probs')
                pns = self.online_net(
                    next_states)  # Probabilities p(s_t+n, ·; θonline)
                nvtx.range_pop()
            else:
                # Without double-Q the target network also selects the action.
                pns = tns

            if self.categorical:
                pns = self.support.expand_as(
                    pns
                ) * pns  # Distribution d_t+n = (z, p(s_t+n, ·; θonline))

            # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            argmax_indices_ns = pns.sum(-1).argmax(-1).unsqueeze(-1).unsqueeze(
                -1)
            if self.categorical:
                argmax_indices_ns = argmax_indices_ns.expand(
                    -1, -1, self.atoms)
            pns_a = tns.gather(
                1, argmax_indices_ns
            )  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            if self.categorical:
                # Compute Tz (Bellman operator T applied to z)
                # Tz = R^n + (γ^n)z (accounting for terminal states)
                Tz = returns.unsqueeze(-1) + nonterminals.float().unsqueeze(
                    -1) * (self.discount**self.n) * self.support.unsqueeze(0)
                Tz = Tz.clamp(min=self.v_min,
                              max=self.v_max)  # Clamp between supported values
                # Compute L2 projection of Tz onto fixed support z
                b = (Tz - self.v_min) / self.delta_z  # b = (Tz - Vmin) / Δz
                l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
                # Fix disappearing probability mass when l = b = u (b is int)
                # The second line re-evaluates (l == u) after the first has
                # modified l, so the order of these two statements matters.
                l[(u > 0) * (l == u)] -= 1
                u[(l < (self.atoms - 1)) * (l == u)] += 1

                # Distribute probability of Tz
                batch_size = states.size(0)
                # m inherits dtype and device from states via new_zeros.
                m = states.new_zeros(batch_size, self.atoms)
                # Per-row start index into the flattened (batch_size * atoms)
                # view of m, converted to the dtype/device of actions.
                offset = torch.linspace(0, ((batch_size - 1) * self.atoms),
                                        batch_size).unsqueeze(1).expand(
                                            batch_size,
                                            self.atoms).to(actions)
                m.view(-1).index_add_(
                    0, (l + offset).view(-1),
                    (pns_a.squeeze(1) * (u.float() - b)
                     ).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
                m.view(-1).index_add_(
                    0, (u + offset).view(-1),
                    (pns_a.squeeze(1) * (b - l.float())
                     ).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)
            else:
                # Scalar target: n-step return plus the discounted value of
                # the selected next action.
                Tz = returns + nonterminals.float() * (
                    self.discount**self.n) * pns_a.squeeze(-1).squeeze(-1)

        if self.categorical:
            loss = -torch.sum(
                m * ps_a.squeeze(1),
                1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
            # NOTE(review): loss is reduced over dim 1 of a (batch_size, atoms)
            # product, so it looks 1-D; if weights is also 1-D, unsqueeze(-1)
            # makes the product below broadcast to (batch, batch) - confirm
            # the shape of weights supplied by the caller.
            weights = weights.unsqueeze(-1)
        else:
            loss = F.mse_loss(ps_a.squeeze(-1).squeeze(-1),
                              Tz,
                              reduction='none')

        nvtx.range_push('agent:loss + step')
        self.optimizer.zero_grad()
        weighted_loss = (weights * loss).mean()
        # Backward runs on the amp-scaled loss; gradients are then clipped on
        # amp's master parameters before the optimizer step.
        with amp.scale_loss(weighted_loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                       self.max_grad_norm)
        self.optimizer.step()
        nvtx.range_pop()

        # The unweighted loss is returned, detached from the graph.
        return loss.detach()

    def update_target_net(self):
        """Copy the online network's state dict into the target network."""
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        """Write the online network's state dict to ``<path>/model.pth``.

        NOTE(review): when ``online_net`` is DDP-wrapped the state dict is the
        wrapper's, while ``__init__`` loads ``args.model`` into an unwrapped
        ``DQN`` - confirm the saved keys are compatible.
        """
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        """Return, as a Python number, the largest per-action value for one unbatched ``state``."""
        with torch.no_grad():
            # unsqueeze(0) adds the leading batch dimension the network is fed.
            q = self.online_net(state.unsqueeze(0).to(self.device))
            if self.categorical:
                q *= self.support
            return q.sum(-1).max(-1)[0].item()

    def train(self):
        """Put the online network in training mode."""
        self.online_net.train()

    def eval(self):
        """Put the online network in evaluation mode."""
        self.online_net.eval()

    def __str__(self):
        """Return the online network's string representation."""
        return self.online_net.__str__()
def main():
    """Parse arguments, build model/optimizer/loaders and run the train/eval loop.

    Side effects: rebinds the module-level ``args`` (and ``alpha`` once per
    epoch), seeds torch and numpy with ``args.seed + args.rank``, initialises
    the NCCL process group when ``WORLD_SIZE`` > 1, downloads CIFAR-100 into
    ``~/Downloads/CIFAR100`` and, on local rank 0, creates an output directory
    holding ``args.yaml`` and the ``Top`` and ``Last`` checkpoint folders.
    ``summary.csv`` and the TensorBoard events are written by every rank.
    A ``KeyboardInterrupt`` during training ends the loop quietly; the best
    metric seen so far is still logged.
    """
    # alpha is rebound once per epoch in the training loop below.
    global args, alpha
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    # Offset the seed by rank so each process draws a different random stream.
    torch.manual_seed(args.seed + args.rank)
    np.random.seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.binarizable:
        Model_binary_patch(model)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum(p.numel() for p in model.parameters())))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()
    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        print('Using amp with --opt-level {}.'.format(args.opt_level))
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        use_amp = True
    else:
        print('Do NOT use amp.')
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    # Drop the reference so the loaded checkpoint can be garbage-collected.
    resume_state = None

    if args.freeze_binary:
        Model_freeze_binary(model)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception:
                # Best effort: training continues with the unconverted model.
                logging.error(
                    'Failed to enable Synchronized BatchNorm. Install Apex or Torch >= 1.1'
                )
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model, device_ids=[args.local_rank
                                           ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    print(num_epochs)

    # Always bind start_epoch so the scheduler stepping and the epoch range
    # below cannot hit an unbound name when neither source provides a value.
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch

    # Guard on lr_scheduler the same way the following statement does.
    if lr_scheduler is not None and args.reset_lr_scheduler is not None:
        # Override every base value with the requested one, then re-apply the
        # schedule at the starting epoch.
        lr_scheduler.base_values = len(
            lr_scheduler.base_values) * [args.reset_lr_scheduler]
        lr_scheduler.step(start_epoch)
    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # Using pruner to get sparse weights
    if args.prune:
        pruner = Pruner_mixed(model, 0, 100, args.pruner)
    else:
        pruner = None

    dataset_train = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100',
                                                  train=True,
                                                  download=True)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader_CIFAR100(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        is_clean_data=args.clean_train,
    )

    dataset_eval = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100',
                                                 train=False,
                                                 download=True)

    loader_eval = create_loader_CIFAR100(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy(
            multiplier=args.softmax_multiplier).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    saver_last_10_epochs = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        # A loss-like metric improves downwards; anything else upwards.
        decreasing = eval_metric == 'loss'
        os.makedirs(output_dir + '/Top')
        os.makedirs(output_dir + '/Last')
        saver = CheckpointSaver(
            checkpoint_dir=output_dir + '/Top',
            decreasing=decreasing,
            max_history=10)  # Save the results of the top 10 epochs
        # This saver is ranked by epoch number (see metric=epoch below), so
        # "better" must always mean "larger"; reusing `decreasing` here would
        # retain the first 10 epochs whenever eval_metric is 'loss'.
        saver_last_10_epochs = CheckpointSaver(
            checkpoint_dir=output_dir + '/Last',
            decreasing=False,
            max_history=10)  # Save the results of the last 10 epochs
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
            f.write('==============================\n')
            f.write(model.__str__())

    # Created on every rank because train_epoch/validate receive it
    # unconditionally; non-zero ranks pass the empty output_dir.
    tensorboard_writer = SummaryWriter(output_dir)

    try:
        for epoch in range(start_epoch, num_epochs):
            alpha = get_alpha(epoch, args)

            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            if pruner:
                pruner.on_epoch_begin(epoch)  # pruning

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        tensorboard_writer=tensorboard_writer,
                                        pruner=pruner)

            if pruner:
                pruner.print_statistics()

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    tensorboard_writer=tensorboard_writer,
                                    epoch=epoch)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            # The header is written only until a best metric has been recorded.
            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    metric=save_metric,
                    use_amp=use_amp)
            if saver_last_10_epochs is not None:
                # keep the 10 most recent epochs by ranking on the epoch index
                saver_last_10_epochs.save_checkpoint(model,
                                                     optimizer,
                                                     args,
                                                     epoch=epoch,
                                                     metric=epoch,
                                                     use_amp=use_amp)

    except KeyboardInterrupt:
        # Allow a manual stop; fall through to report the best result so far.
        pass
    finally:
        # Flush pending events and release the writer's file handle.
        tensorboard_writer.close()
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))