Example #1
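Both examples are methods of a trainer class and rely on names defined elsewhere in the project. A minimal sketch of the imports they assume is given below; AverageMeter, sync_vals, and accuracy are project-specific helpers whose signatures are only inferred from their usage in these snippets.

import json
import os
import time

import torch
import torch.nn.functional as F
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # LRScheduler on newer PyTorch

# Project-specific helpers assumed by both methods (signatures inferred from
# usage): AverageMeter (windowed running average with update/avg/last and
# time_preds), sync_vals (gathers a scalar from every rank into a tensor) and
# accuracy (top-1/top-5 accuracy from logits and targets).
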
    def loop_finetune(self, pretrained_state, max_lp, max_ep, rewarded_ep,
                      op_cfg, sc_cfg, candidate_ratio, finetune_lsmooth):
        self.g_tb_lg.add_scalars('probs', self.agent.get_prob_dict(), -1)
        self.g_tb_lg.add_histogram('probs_dist', self.agent.get_prob_tensor(),
                                   -1)
        for t in (-1, max_lp):
            self.g_tb_lg.add_scalar('ppo_step', self.agent.max_training_times, t)

        max_it = self.auged_sub_train_iters
        loader = self.auged_sub_train_ld
        # assert max_it == len(loader)
        agent_param_his = []

        best_rewards_mean = 0
        best_rewards_lp = 0
        best_agent_state = {}
        candidate_ep = max(round(max_ep * candidate_ratio), 1)
        loop_speed = AverageMeter(4)
        crit = self.criterion if finetune_lsmooth else F.cross_entropy
        for lp in range(max_lp):
            lp_str = f'%{len(str(max_lp))}d' % (lp + 1)
            lp_start_t = time.time()
            self.model.load_state_dict(pretrained_state['model'])
            self.model.train()
            op, sc = self.create_op_sc(self.model, op_cfg, sc_cfg, max_it)
            op: torch.optim.Optimizer
            op.load_state_dict(pretrained_state['op'])

            epoch_speed = AverageMeter(1)
            acc1s = []
            for ep in range(max_ep):
                ep_str = f'%{len(str(max_ep))}d' % (ep + 1)
                ep_start_t = time.time()
                for it, (inp, tar, op_indices) in enumerate(loader):
                    global_it = ep * max_it + it
                    self.agent.record(op_indices)
                    inp, tar = inp.cuda(), tar.cuda()
                    loss = crit(self.model(inp), tar)
                    op.zero_grad()
                    loss.backward()
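                    # Clip gradients when a max-norm threshold is configured;
                    # -233 is a sentinel recorded when clipping is disabled.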
                    if self.model_grad_clip is not None:
                        total_norm = torch.nn.utils.clip_grad_norm_(
                            self.model.parameters(), self.model_grad_clip)
                    else:
                        total_norm = -233
                    clipped_norm = torch.cat([
                        p.grad.data.view(-1) for p in self.model.parameters()
                    ]).abs_().norm()

                    sc.step()  # sc.step() before op.step()
                    lr = sc.get_lr()[0]
                    clipped_lr = lr * (clipped_norm / total_norm)
                    op.step()

                val_loss, val_acc1, val_acc5 = self.val()
                self.model.train()
                acc1s.append(val_acc1)
                if lp % 20 == 0:
                    if ep == 0:
                        self.lg.info(f'==> at {self.exp_root}')
                    self.g_tb_lg.add_scalars('rk0_ft_v_loss',
                                             {f'loop_{lp}': val_loss}, ep)
                    self.g_tb_lg.add_scalars('rk0_ft_v_acc1',
                                             {f'loop_{lp}': val_acc1}, ep)
                    self.g_tb_lg.add_scalars('rk0_ft_v_acc5',
                                             {f'loop_{lp}': val_acc5}, ep)

                epoch_speed.update(time.time() - ep_start_t)
                remain_time, finish_time = epoch_speed.time_preds(
                    max_ep - ep - 1)
                self.lg.info(f'lp[{lp_str}/{max_lp}], ep[{ep_str}/{max_ep}]'
                             f' vacc1: {float(val_acc1):5.2f},'
                             f' verr1: {float(100.-val_acc1):5.2f},'
                             f' time cost: {time.time()-ep_start_t:.3f},'
                             f' op_freq.s: {self.agent.op_freq.sum()},'
                             f' rem-t: {remain_time} ({finish_time})')

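            # Reward for this loop: mean of the best `rewarded_ep` validation
            # acc1 values among the last `candidate_ep` epochs.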
            acc1s = acc1s[-candidate_ep:]
            rewarded_acc1s = sorted(acc1s)[-rewarded_ep:]
            reward = sum(rewarded_acc1s) / len(rewarded_acc1s)
            rewards = sync_vals(self.dist, reward, fmt=None)
            rewards_mean = rewards.mean().item()
            if self.agent.initial_baseline is not None:
                d = f'{rewards_mean-self.agent.initial_baseline:.3f}'
            else:
                d = None

            if rewards_mean > best_rewards_mean:
                best_rewards_mean = rewards_mean
                best_rewards_lp = lp - 1
                best_agent_state = self.agent.state_dict()
                best_agent_state = {
                    'first_param': best_agent_state['first_param'].data.clone(),
                    'second_param': best_agent_state['second_param'].data.clone(),
                }

                if self.dist.is_master() and d is not None:
                    # Keep only the newest best-agent checkpoint: clear the
                    # directory, then save the current best state.
                    for root, dirs, files in os.walk(self.best_agent_ckpt_root):
                        for f in files:
                            os.remove(os.path.join(root, f))
                    torch.save(
                        {'lp': lp, 'agent': best_agent_state},
                        os.path.join(self.best_agent_ckpt_root,
                                     f'after_lp{best_rewards_lp}_d{d}.pth.tar'))

            if lp == 0:
                self.agent.set_baselines(initial_baseline=rewards_mean,
                                         running_baseline=reward)
                for t in (0, max_lp // 2, max_lp - 1):
                    self.g_tb_lg.add_scalars('reward',
                                             {'g_ini_bsln': rewards_mean}, t)
                    self.l_tb_lg.add_scalars(
                        'reward',
                        {f'rk{self.dist.rank}_ini_run_bsln': reward}, t)

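            # One agent update (PPO) driven by this loop's reward; the returned
            # step count is logged below.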
            ppo_step_times = self.agent.step(reward=reward)
            self.g_tb_lg.add_scalar('agent_lr',
                                    self.agent.scheduler.get_lr()[0], lp)

            loop_speed.update(time.time() - lp_start_t)
            remain_time, finish_time = loop_speed.time_preds(max_lp - lp - 1)
            self.lg.info(
                f'==> loop[{lp_str}/{max_lp}],'
                f' time cost: {(time.time()-lp_start_t) / 60:.2f} min,'
                f' rem-t[{remain_time}] ({finish_time}),'
                f' rew={rewards}')
            if self.dist.is_master():
                agent_param_his.append(self.agent.get_params_as_list())

            self.g_tb_lg.add_scalar('ppo_step', ppo_step_times, lp)
            self.l_tb_lg.add_scalars(
                'reward',
                {f'rk{self.dist.rank}_run_bsln': self.agent.running_baseline},
                lp)
            self.l_tb_lg.add_scalars(
                'advance', {f'rk{self.dist.rank}_adv': self.agent.advance_val},
                lp)

            self.g_tb_lg.add_scalars('probs', self.agent.get_prob_dict(), lp)
            self.g_tb_lg.add_histogram('probs_dist',
                                       self.agent.get_prob_tensor(), lp)

            if self.dist.is_master():
                torch.save(
                    {
                        'lp': lp,
                        'agent': self.agent.state_dict(),
                    },
                    os.path.join(
                        self.agents_ckpt_root,
                        f'lp{lp}_d{d}_rew_mean{rewards_mean:.2f}.pth.tar'))

            torch.cuda.empty_cache()

            if self.dist.is_master():
                f_name = os.path.join(self.ckpt_root, 'agent_param_his.json')
                self.lg.info(f'dump agent params into {f_name}')
                with open(f_name, 'w') as fp:
                    json.dump(agent_param_his, fp)
                # if lp == 0:
                #     self.lg.info(f'dumped list[0]: {agent_param_his[0]}')

            self.dist.barrier()
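            # Early stop: keep looping only while the early-stop flag path exists.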
            if not os.path.exists(self.early_stop_root):
                break

        for t in (0, best_rewards_lp, max_lp):
            self.meta_tb_lg.add_scalar('best_rew_mean', best_rewards_mean, t)
            self.g_tb_lg.add_scalar('best_rew_mean', best_rewards_mean, t)
        return {'lp': best_rewards_lp, 'agent': best_agent_state}
Example #2
    def _train_with_aug(self,
                        max_iters,
                        loader,
                        max_ep,
                        op_cfg,
                        sc_cfg,
                        sync_mid,
                        lsmooth,
                        save_mode='best',
                        prefix='pre'):
        # assert max_iters == len(loader)
        self.model.train()

        max_it = max_iters
        max_global_it = max_ep * max_it
        train_log_freq = max_it // 10
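        # Test 32x less frequently during the first 75% of epochs, then at the
        # base frequency for the final quarter (see `is_late` below).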
        test_freqs = [self.test_freq * 32, self.test_freq]

        speed = AverageMeter(max_it)
        tr_loss = AverageMeter(train_log_freq)
        tr_acc1 = AverageMeter(train_log_freq)
        tr_acc5 = AverageMeter(train_log_freq)

        op, sc = self.create_op_sc(self.model,
                                   op_cfg,
                                   sc_cfg,
                                   iters_per_epoch=max_it)
        op: Optimizer
        sc: LRScheduler
        best_acc1 = 0
        start_train_t = time.time()
        crit = self.criterion if lsmooth else F.cross_entropy
        for ep in range(max_ep):
            ep_str = f'%{len(str(max_ep))}d' % (ep + 1)
            is_late = int(ep >= 0.75 * max_ep)
            test_freq = test_freqs[is_late]
            if ep % 32 == 0:
                self.lg.info(f'==> at {self.exp_root}')

            last_t = time.time()
            for it, tup in enumerate(loader):
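                # An augmented loader yields (input, target, op_indices);
                # a plain loader yields (input, target).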
                if len(tup) == 3:
                    inp, tar, _ = tup
                else:
                    inp, tar = tup
                it_str = f'%{len(str(max_it))}d' % (it + 1)
                global_it = ep * max_it + it
                data_t = time.time()

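                # One-time round-robin print to sanity-check that every rank in
                # the distributed setup is alive.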
                if global_it == 1:
                    for i in range(self.dist.world_size):
                        if self.dist.rank == i:
                            print(f'rk[{i:2d}] dist test')
                        self.dist.barrier()

                inp, tar = inp.cuda(), tar.cuda()
                cuda_t = time.time()

                logits = self.model(inp)
                loss = crit(logits, tar)
                tr_loss.update(loss.item())
                op.zero_grad()
                loss.backward()
                if self.dist_training:
                    pass
                if self.model_grad_clip is not None:
                    total_norm = torch.nn.utils.clip_grad_norm_(
                        self.model.parameters(), self.model_grad_clip)
                else:
                    total_norm = -233
                clipped_norm = torch.cat([
                    p.grad.data.view(-1) for p in self.model.parameters()
                ]).abs_().norm()

                sc.step()  # sc.step() before op.step()
                lr = sc.get_lr()[0]
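                # Effective LR after gradient clipping (only meaningful when
                # clipping is enabled); logged alongside lr below.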
                clipped_lr = lr * (clipped_norm / total_norm)

                op.step()
                acc1, acc5 = accuracy(logits, tar)
                op_t = time.time()

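                # Weight the accuracy meters by batch size relative to a
                # nominal batch size of 128.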
                total_r = tar.shape[0] / 128
                tr_acc1.update(val=acc1, num=total_r)
                tr_acc5.update(val=acc5, num=total_r)

                if global_it % test_freq == 0 or global_it == max_global_it - 1:
                    test_loss, test_acc1, test_acc5 = self.test()
                    test_t = time.time()
                    self.model.train()
                    is_best = test_acc1 >= best_acc1
                    best_acc1 = max(test_acc1, best_acc1)

                    if self.dist.is_master() and it + 1 == max_it:
                        remain_time, finish_time = speed.time_preds(
                            max_global_it - global_it - 1)
                        self.lg.info(
                            f'ep[{ep_str}/{max_ep}], it[{it_str}/{max_it}]:'
                            f' tr-err1[{100-tr_acc1.last:5.2f}] ({100-tr_acc1.avg:5.2f}),'
                            f' tr-loss[{tr_loss.last:.4f}] ({tr_loss.avg:.4f}),'
                            f' te-err1[{100-test_acc1:5.2f}],'
                            f' te-loss[{test_loss:.4f}],\n'
                            f' data[{data_t-last_t:.3f}],'
                            f' cuda[{cuda_t-data_t:.3f}],'
                            f' bp[{op_t-cuda_t:.3f}],'
                            f' te[{test_t-op_t:.3f}]'
                            f' rem-t[{remain_time}] ({finish_time})'
                            f' lr[{lr:.4g}] ({clipped_lr:.4g})')

                    state = {
                        'model': self.model.state_dict(),
                        'op': op.state_dict(),
                        'last_iter': global_it,
                    }

                    model_ckpt_path = os.path.join(
                        self.ckpt_root,
                        f'rk{self.dist.rank}_{prefix}_{save_mode}.pth.tar')
                    if save_mode == 'best' and is_best:
                        self.lg.info(
                            f'==> saving best model ckpt (err{100-test_acc1:.3f}) at {os.path.abspath(model_ckpt_path)}...'
                        )
                        torch.save(state, model_ckpt_path)
                    elif save_mode == 'last':
                        torch.save(state, model_ckpt_path)

                speed.update(time.time() - last_t)
                last_t = time.time()

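        # Multi-rank run: checkpoint the rank whose final test acc1 is the
        # median across ranks, and optionally load that checkpoint on all ranks.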
        if self.dist.world_size > 1:
            test_loss, test_acc1, test_acc5 = self.test()
            acc1_ts: torch.Tensor = sync_vals(self.dist, test_acc1, None)
            mid_rank = acc1_ts.argsort()[self.dist.world_size // 2].item()
            mid_ckpt_path = os.path.join(
                self.ckpt_root,
                f'midrk{mid_rank}_{prefix}_enderr{100-acc1_ts[mid_rank].item():.2f}.pth.tar'
            )
            if self.dist.rank == mid_rank:
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'op': op.state_dict(),
                    }, mid_ckpt_path)
            self.dist.barrier()

            if sync_mid:
                mid_ckpt = torch.load(mid_ckpt_path, map_location='cpu')
                self.model.load_state_dict(mid_ckpt['model'])
                op.load_state_dict(mid_ckpt['op'])

            best_errs: torch.Tensor = sync_vals(self.dist, 100 - best_acc1,
                                                None)
            best_err: float = best_errs.mean().item()
            self.lg.info(
                f'==> {prefix}-training finished, mid rank={mid_rank},'
                f' total time cost: {(time.time()-start_train_t)/60:.2f} min,'
                f' test err @1: mean={best_err:.3f}')
        else:
            best_err = 100 - best_acc1
            self.lg.info(
                f'==> {prefix}-training finished,'
                f' total time cost: {(time.time()-start_train_t)/60:.2f} min,'
                f' test err @1: {100-best_acc1:.3f}')

        for t in (0, max_ep):
            self.meta_tb_lg.add_scalar(f'{prefix}_best_err', best_err, t)
            self.g_tb_lg.add_scalar(f'{prefix}_best_err', best_err, t)
        return {
            'model': self.model.state_dict(),
            'op': op.state_dict(),
            'last_iter': max_global_it
        }
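
For reference, a minimal sketch of the AverageMeter helper that both methods depend on; the member names (update, avg, last, time_preds) are taken from the usage above, while the internals are hypothetical and will differ from the project's real implementation.

import time
from collections import deque


class AverageMeter:
    """Windowed running average; a hypothetical stand-in inferred from usage."""

    def __init__(self, window):
        self.window = max(int(window), 1)
        self.vals = deque(maxlen=self.window)
        self.last = 0.0
        self.avg = 0.0

    def update(self, val, num=1):
        # Record one (value, weight) sample and refresh the windowed average.
        self.last = float(val)
        self.vals.append((float(val), float(num)))
        total_num = sum(n for _, n in self.vals)
        self.avg = sum(v * n for v, n in self.vals) / max(total_num, 1e-12)

    def time_preds(self, remaining_steps):
        # Assuming update() was fed per-step durations in seconds, predict the
        # remaining wall time and a human-readable finish time.
        remaining_secs = self.avg * max(remaining_steps, 0)
        finish = time.strftime(
            '%m-%d %H:%M', time.localtime(time.time() + remaining_secs))
        return f'{remaining_secs / 60:.1f}min', finish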