Code Example #1
File: agent.py Project: tevfikoguz/cule
import os

import torch
import torch.nn.functional as F
from torch import optim
from torch.cuda import nvtx

# NVIDIA Apex supplies the mixed-precision (amp) utilities and DDP wrapper used below
from apex import amp
from apex.parallel import DistributedDataParallel as DDP

from model import DQN  # project-local network definition (module path assumed)


class Agent():
    def __init__(self, args, action_space):
        self.action_space = action_space
        self.n = args.multi_step
        self.discount = args.discount
        self.target_update = args.target_update
        self.categorical = args.categorical
        self.noisy_linear = args.noisy_linear
        self.double_q = args.double_q
        self.max_grad_norm = args.max_grad_norm
        self.device = torch.device('cuda', args.gpu)
        self.num_param_updates = 0

        if args.categorical:
            self.atoms = args.atoms
            self.v_min = args.v_min
            self.v_max = args.v_max
            self.support = torch.linspace(
                self.v_min, args.v_max,
                self.atoms).to(device=self.device)  # Support (range) of z
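            # Spacing between adjacent support atoms (used by the projection in learn())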
            self.delta_z = (args.v_max - self.v_min) / (self.atoms - 1)

        self.online_net = DQN(args, self.action_space.n).to(device=self.device)

        if args.model and os.path.isfile(args.model):
            # Always load tensors onto CPU by default, will shift to GPU if necessary
            self.online_net.load_state_dict(
                torch.load(args.model, map_location='cpu'))
        self.online_net.train()

        self.target_net = DQN(args, self.action_space.n).to(device=self.device)
        self.update_target_net()
        self.target_net.eval()
        for param in self.target_net.parameters():
            param.requires_grad = False

        self.optimizer = optim.Adam(self.online_net.parameters(),
                                    lr=args.lr,
                                    eps=args.adam_eps,
                                    amsgrad=True)

        [self.online_net, self.target_net], self.optimizer = amp.initialize(
            [self.online_net, self.target_net],
            self.optimizer,
            opt_level=args.opt_level,
            loss_scale=args.loss_scale)
        if args.distributed:
            self.online_net = DDP(self.online_net, delay_allreduce=True)
            self.target_net = DDP(self.target_net, delay_allreduce=True)

    # Resets noisy weights in all linear layers (of online net only)
    def reset_noise(self):
        if isinstance(self.online_net, DQN):
            self.online_net.reset_noise()
        else:
            self.online_net.module.reset_noise()

    # Acts based on single state (no batch)
    def act(self, state):
        with torch.no_grad():
            probs = self.online_net(state.to(self.device))
            if self.categorical:
                probs = self.support.expand_as(probs) * probs
            actions = probs.sum(-1).argmax(-1).to(state.device)
            return actions

    # Acts with an ε-greedy policy (used for evaluation only)
    def act_e_greedy(
            self,
            state,
            epsilon=0.001):  # High ε can reduce evaluation scores drastically
        actions = self.act(state)

        mask = torch.rand(
            state.size(0), device=state.device, dtype=torch.float32) < epsilon
        masked = mask.sum().item()
        if masked > 0:
            actions[mask] = torch.randint(0,
                                          self.action_space.n, (masked, ),
                                          device=state.device,
                                          dtype=torch.long)

        return actions

    def learn(self, states, actions, returns, next_states, nonterminals,
              weights):

        tactions = actions.unsqueeze(-1).unsqueeze(-1)
        if self.categorical:
            tactions = tactions.expand(-1, -1, self.atoms)

        # Calculate current state probabilities (online network noise already sampled)
        nvtx.range_push('agent:online (state) probs')
        ps = self.online_net(
            states, log=True)  # Log probabilities log p(s_t, ·; θonline)
        ps_a = ps.gather(1, tactions)  # log p(s_t, a_t; θonline)
        nvtx.range_pop()

        with torch.no_grad():
            # Sample new target-net noise
            if isinstance(self.target_net, DQN):
                self.target_net.reset_noise()
            else:
                self.target_net.module.reset_noise()

            nvtx.range_push('agent:target (next state) probs')
            tns = self.target_net(
                next_states)  # Probabilities p(s_t+n, ·; θtarget)
            nvtx.range_pop()

            if self.double_q:
                # Calculate nth next state probabilities
                nvtx.range_push('agent:online (next state) probs')
                pns = self.online_net(
                    next_states)  # Probabilities p(s_t+n, ·; θonline)
                nvtx.range_pop()
            else:
                pns = tns

            if self.categorical:
                # Distribution d_t+n = (z, p(s_t+n, ·; θonline))
                pns = self.support.expand_as(pns) * pns

            # Perform argmax action selection using online network: argmax_a[(z, p(s_t+n, a; θonline))]
            argmax_indices_ns = pns.sum(-1).argmax(-1).unsqueeze(-1).unsqueeze(
                -1)
            if self.categorical:
                argmax_indices_ns = argmax_indices_ns.expand(
                    -1, -1, self.atoms)
            pns_a = tns.gather(
                1, argmax_indices_ns
            )  # Double-Q probabilities p(s_t+n, argmax_a[(z, p(s_t+n, a; θonline))]; θtarget)

            if self.categorical:
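                # C51 projection: each target atom Tz_j splits its probability
                # mass between the two nearest support atoms l and u in
                # proportion to proximity, so total mass is preserved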
                # Compute Tz (Bellman operator T applied to z)
                # Tz = R^n + (γ^n)z (accounting for terminal states)
                Tz = returns.unsqueeze(-1) + nonterminals.float().unsqueeze(
                    -1) * (self.discount**self.n) * self.support.unsqueeze(0)
                Tz = Tz.clamp(min=self.v_min,
                              max=self.v_max)  # Clamp between supported values
                # Compute L2 projection of Tz onto fixed support z
                b = (Tz - self.v_min) / self.delta_z  # b = (Tz - Vmin) / Δz
                l, u = b.floor().to(torch.int64), b.ceil().to(torch.int64)
                # Fix disappearing probability mass when l = b = u (b is int)
                l[(u > 0) * (l == u)] -= 1
                u[(l < (self.atoms - 1)) * (l == u)] += 1

                # Distribute probability of Tz
                batch_size = states.size(0)
                m = states.new_zeros(batch_size, self.atoms)
                offset = torch.linspace(0, ((batch_size - 1) * self.atoms),
                                        batch_size).unsqueeze(1).expand(
                                            batch_size, self.atoms).to(actions)
                m.view(-1).index_add_(
                    0, (l + offset).view(-1),
                    (pns_a.squeeze(1) * (u.float() - b)
                     ).view(-1))  # m_l = m_l + p(s_t+n, a*)(u - b)
                m.view(-1).index_add_(
                    0, (u + offset).view(-1),
                    (pns_a.squeeze(1) * (b - l.float())
                     ).view(-1))  # m_u = m_u + p(s_t+n, a*)(b - l)
            else:
                Tz = returns + nonterminals.float() * (
                    self.discount**self.n) * pns_a.squeeze(-1).squeeze(-1)

        if self.categorical:
            loss = -torch.sum(
                m * ps_a.squeeze(1),
                1)  # Cross-entropy loss (minimises DKL(m||p(s_t, a_t)))
            weights = weights.unsqueeze(-1)
        else:
            loss = F.mse_loss(ps_a.squeeze(-1).squeeze(-1),
                              Tz,
                              reduction='none')

        nvtx.range_push('agent:loss + step')
        self.optimizer.zero_grad()
        weighted_loss = (weights * loss).mean()
        with amp.scale_loss(weighted_loss, self.optimizer) as scaled_loss:
            scaled_loss.backward()
        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                       self.max_grad_norm)
        self.optimizer.step()
        nvtx.range_pop()

        return loss.detach()

    def update_target_net(self):
        self.target_net.load_state_dict(self.online_net.state_dict())

    # Save model parameters on current device (don't move model between devices)
    def save(self, path):
        torch.save(self.online_net.state_dict(),
                   os.path.join(path, 'model.pth'))

    # Evaluates Q-value based on single state (no batch)
    def evaluate_q(self, state):
        with torch.no_grad():
            q = self.online_net(state.unsqueeze(0).to(self.device))
            if self.categorical:
                q *= self.support
            return q.sum(-1).max(-1)[0].item()

    def train(self):
        self.online_net.train()

    def eval(self):
        self.online_net.eval()

    def __str__(self):
        return self.online_net.__str__()
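
Usage sketch (not part of the original file): a minimal driving loop, assuming a hypothetical `args` namespace carrying the fields `__init__` reads (the values below are typical Rainbow settings, not taken from the project) and hypothetical `env`, `memory`, and `num_updates` objects; replay-buffer bookkeeping is elided.

from argparse import Namespace

# Hypothetical hyper-parameters; every field below is read in Agent.__init__
args = Namespace(multi_step=3, discount=0.99, target_update=8000,
                 categorical=True, noisy_linear=True, double_q=True,
                 max_grad_norm=10.0, gpu=0, atoms=51, v_min=-10.0, v_max=10.0,
                 model=None, lr=6.25e-5, adam_eps=1.5e-4,
                 opt_level='O0', loss_scale=1.0, distributed=False)

agent = Agent(args, env.action_space)  # env: a Gym-style vectorized Atari env
agent.train()

for update in range(num_updates):
    # sample a prioritized batch from a replay buffer (elided here)
    states, actions, returns, next_states, nonterminals, weights = memory.sample()
    agent.reset_noise()  # resample NoisyNet weights for the online net
    loss = agent.learn(states, actions, returns, next_states, nonterminals, weights)
    if update % args.target_update == 0:
        agent.update_target_net()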
Code Example #2
import os
import logging
from datetime import datetime

import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.utils.tensorboard import SummaryWriter

# The remaining names used below (setup_default_logging, _parse_args,
# create_model, resolve_data_config, create_optimizer, create_scheduler,
# resume_checkpoint, CheckpointSaver, get_outdir, update_summary,
# FastCollateMixup, LabelSmoothingCrossEntropy, SoftTargetCrossEntropy,
# has_apex, amp, DDP, convert_syncbn_model, Model_binary_patch,
# Model_freeze_binary, Pruner_mixed, create_loader_CIFAR100, train_epoch,
# validate, get_alpha) are timm-style and project-local helpers assumed to
# be importable from this repository.


def main():
    global args
    setup_default_logging()
    args, args_text = _parse_args()

    args.prefetcher = not args.no_prefetcher
    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1
        if args.distributed and args.num_gpu > 1:
            logging.warning(
                'Using more than one GPU per process in distributed mode is not allowed. Setting num_gpu to 1.'
            )
            args.num_gpu = 1

    args.device = 'cuda:0'
    args.world_size = 1
    args.rank = 0  # global rank
    if args.distributed:
        args.num_gpu = 1
        args.device = 'cuda:%d' % args.local_rank
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        args.world_size = torch.distributed.get_world_size()
        args.rank = torch.distributed.get_rank()
    assert args.rank >= 0

    if args.distributed:
        logging.info(
            'Training in distributed mode with multiple processes, 1 GPU per process. Process %d, total %d.'
            % (args.rank, args.world_size))
    else:
        logging.info('Training with a single process on %d GPUs.' %
                     args.num_gpu)

    torch.manual_seed(args.seed + args.rank)
    np.random.seed(args.seed + args.rank)

    model = create_model(args.model,
                         pretrained=args.pretrained,
                         num_classes=args.num_classes,
                         drop_rate=args.drop,
                         global_pool=args.gp,
                         bn_tf=args.bn_tf,
                         bn_momentum=args.bn_momentum,
                         bn_eps=args.bn_eps,
                         checkpoint_path=args.initial_checkpoint)

    if args.binarizable:
        Model_binary_patch(model)

    if args.local_rank == 0:
        logging.info('Model %s created, param count: %d' %
                     (args.model, sum([m.numel()
                                       for m in model.parameters()])))

    data_config = resolve_data_config(vars(args),
                                      model=model,
                                      verbose=args.local_rank == 0)

    if args.num_gpu > 1:
        if args.amp:
            logging.warning(
                'AMP does not work well with nn.DataParallel, disabling. Use distributed mode for multi-GPU AMP.'
            )
            args.amp = False
        model = nn.DataParallel(model,
                                device_ids=list(range(args.num_gpu))).cuda()

    else:
        model.cuda()

    optimizer = create_optimizer(args, model)

    use_amp = False
    if has_apex and args.amp:
        print('Using amp with --opt-level {}.'.format(args.opt_level))
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.opt_level)
        use_amp = True
    else:
        print('Do NOT use amp.')
    if args.local_rank == 0:
        logging.info('NVIDIA APEX {}. AMP {}.'.format(
            'installed' if has_apex else 'not installed',
            'on' if use_amp else 'off'))

    # optionally resume from a checkpoint
    resume_state = {}
    resume_epoch = None
    if args.resume:
        resume_state, resume_epoch = resume_checkpoint(model, args.resume)
    if resume_state and not args.no_resume_opt:
        if 'optimizer' in resume_state:
            if args.local_rank == 0:
                logging.info('Restoring Optimizer state from checkpoint')
            optimizer.load_state_dict(resume_state['optimizer'])
        if use_amp and 'amp' in resume_state and 'load_state_dict' in amp.__dict__:
            if args.local_rank == 0:
                logging.info('Restoring NVIDIA AMP state from checkpoint')
            amp.load_state_dict(resume_state['amp'])
    resume_state = None

    if args.freeze_binary:
        Model_freeze_binary(model)

    if args.distributed:
        if args.sync_bn:
            try:
                if has_apex:
                    model = convert_syncbn_model(model)
                else:
                    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(
                        model)
                if args.local_rank == 0:
                    logging.info(
                        'Converted model to use Synchronized BatchNorm.')
            except Exception as e:
                logging.error(
                    'Failed to enable Synchronized BatchNorm: %s. '
                    'Install Apex or Torch >= 1.1', e)
        if has_apex:
            model = DDP(model, delay_allreduce=True)
        else:
            if args.local_rank == 0:
                logging.info(
                    "Using torch DistributedDataParallel. Install NVIDIA Apex for Apex DDP."
                )
            model = DDP(model,
                        device_ids=[args.local_rank
                                    ])  # can use device str in Torch >= 1.1
        # NOTE: EMA model does not need to be wrapped by DDP

    lr_scheduler, num_epochs = create_scheduler(args, optimizer)
    start_epoch = 0
    if args.start_epoch is not None:
        # a specified start_epoch will always override the resume epoch
        start_epoch = args.start_epoch
    elif resume_epoch is not None:
        start_epoch = resume_epoch
    if args.reset_lr_scheduler is not None:
        lr_scheduler.base_values = len(
            lr_scheduler.base_values) * [args.reset_lr_scheduler]
        lr_scheduler.step(start_epoch)

    if lr_scheduler is not None and start_epoch > 0:
        lr_scheduler.step(start_epoch)

    if args.local_rank == 0:
        logging.info('Scheduled epochs: {}'.format(num_epochs))

    # Using pruner to get sparse weights
    if args.prune:
        pruner = Pruner_mixed(model, 0, 100, args.pruner)
    else:
        pruner = None

    dataset_train = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100',
                                                  train=True,
                                                  download=True)

    collate_fn = None
    if args.prefetcher and args.mixup > 0:
        collate_fn = FastCollateMixup(args.mixup, args.smoothing,
                                      args.num_classes)

    loader_train = create_loader_CIFAR100(
        dataset_train,
        input_size=data_config['input_size'],
        batch_size=args.batch_size,
        is_training=True,
        use_prefetcher=args.prefetcher,
        rand_erase_prob=args.reprob,
        rand_erase_mode=args.remode,
        rand_erase_count=args.recount,
        color_jitter=args.color_jitter,
        auto_augment=args.aa,
        interpolation='random',
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
        collate_fn=collate_fn,
        is_clean_data=args.clean_train,
    )

    dataset_eval = torchvision.datasets.CIFAR100(root='~/Downloads/CIFAR100',
                                                 train=False,
                                                 download=True)

    loader_eval = create_loader_CIFAR100(
        dataset_eval,
        input_size=data_config['input_size'],
        batch_size=4 * args.batch_size,
        is_training=False,
        use_prefetcher=args.prefetcher,
        interpolation=data_config['interpolation'],
        mean=data_config['mean'],
        std=data_config['std'],
        num_workers=args.workers,
        distributed=args.distributed,
    )

    if args.mixup > 0.:
        # smoothing is handled with mixup label transform
        train_loss_fn = SoftTargetCrossEntropy(
            multiplier=args.softmax_multiplier).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    elif args.smoothing:
        train_loss_fn = LabelSmoothingCrossEntropy(
            smoothing=args.smoothing).cuda()
        validate_loss_fn = nn.CrossEntropyLoss().cuda()
    else:
        train_loss_fn = nn.CrossEntropyLoss().cuda()
        validate_loss_fn = train_loss_fn

    eval_metric = args.eval_metric
    best_metric = None
    best_epoch = None
    saver = None
    saver_last_10_epochs = None
    output_dir = ''
    if args.local_rank == 0:
        output_base = args.output if args.output else './output'
        exp_name = '-'.join([
            datetime.now().strftime("%Y%m%d-%H%M%S"), args.model,
            str(data_config['input_size'][-1])
        ])
        output_dir = get_outdir(output_base, 'train', exp_name)
        decreasing = True if eval_metric == 'loss' else False
        os.makedirs(output_dir + '/Top')
        os.makedirs(output_dir + '/Last')
        saver = CheckpointSaver(
            checkpoint_dir=output_dir + '/Top',
            decreasing=decreasing,
            max_history=10)  # Save the results of the top 10 epochs
        saver_last_10_epochs = CheckpointSaver(
            checkpoint_dir=output_dir + '/Last',
            decreasing=decreasing,
            max_history=10)  # Save the results of the last 10 epochs
        with open(os.path.join(output_dir, 'args.yaml'), 'w') as f:
            f.write(args_text)
            f.write('==============================\n')
            f.write(model.__str__())
            # if pruner:
            #     f.write('\n Sparsity \n')
            #     #f.write(pruner.threshold_dict.__str__())
            #     f.write('\n pruner.start_epoch={}, pruner.end_epoch={}'.format(pruner.start_epoch, pruner.end_epoch))

    # NOTE: ranks other than 0 leave output_dir as '', so their SummaryWriter
    # falls back to the default ./runs/<timestamp> log directory
    tensorboard_writer = SummaryWriter(output_dir)

    try:
        for epoch in range(start_epoch, num_epochs):

            global alpha
            alpha = get_alpha(epoch, args)

            if args.distributed:
                loader_train.sampler.set_epoch(epoch)

            if pruner:
                pruner.on_epoch_begin(epoch)  # pruning

            train_metrics = train_epoch(epoch,
                                        model,
                                        loader_train,
                                        optimizer,
                                        train_loss_fn,
                                        args,
                                        lr_scheduler=lr_scheduler,
                                        saver=saver,
                                        output_dir=output_dir,
                                        use_amp=use_amp,
                                        tensorboard_writer=tensorboard_writer,
                                        pruner=pruner)

            if pruner:
                pruner.print_statistics()

            eval_metrics = validate(model,
                                    loader_eval,
                                    validate_loss_fn,
                                    args,
                                    tensorboard_writer=tensorboard_writer,
                                    epoch=epoch)

            if lr_scheduler is not None:
                # step LR for next epoch
                lr_scheduler.step(epoch + 1, eval_metrics[eval_metric])

            update_summary(epoch,
                           train_metrics,
                           eval_metrics,
                           os.path.join(output_dir, 'summary.csv'),
                           write_header=best_metric is None)

            if saver is not None:
                # save proper checkpoint with eval metric
                save_metric = eval_metrics[eval_metric]
                best_metric, best_epoch = saver.save_checkpoint(
                    model,
                    optimizer,
                    args,
                    epoch=epoch,
                    metric=save_metric,
                    use_amp=use_amp)
            if saver_last_10_epochs is not None:
                # save the checkpoint in the last 10 epochs
                _, _ = saver_last_10_epochs.save_checkpoint(model,
                                                            optimizer,
                                                            args,
                                                            epoch=epoch,
                                                            metric=epoch,
                                                            use_amp=use_amp)

    except KeyboardInterrupt:
        pass
    if best_metric is not None:
        logging.info('*** Best metric: {0} (epoch {1})'.format(
            best_metric, best_epoch))
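
The listing defines main() but never invokes it; the original script presumably ends with the standard entry-point guard:

# In distributed mode the script expects one process per GPU, e.g. launched via
# python -m torch.distributed.launch --nproc_per_node=<N> <script>.py ...
if __name__ == '__main__':
    main()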