Ejemplo n.º 1
0
def main(config):
    """Distill a teacher model into a student network, logging to wandb.

    Expects ``config`` to provide: teacher_path, device, model_args,
    source, data_args, optimizer_args, max_epoch, and run_name.
    Checkpoints and best/periodic weights are written to the wandb run dir.
    """
    teacher = Teacher(config['teacher_path']).to(config['device'])

    net = Network(**config['model_args']).to(config['device'])
    data_train, data_val = get_dataset(config['source'])(**config['data_args'])

    optim = torch.optim.Adam(net.parameters(), **config['optimizer_args'])
    # Milestones must be integer epoch indices: MultiStepLR matches them
    # against the integer epoch counter, so a float like 37.5 (odd
    # max_epoch) would never match and the decay would silently be skipped.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optim,
        milestones=[int(mult * config['max_epoch']) for mult in [0.5, 0.75]],
        gamma=0.5)

    wandb.init(project='task-distillation-07',
               config=config,
               id=config['run_name'],
               resume='auto')
    wandb.save(str(Path(wandb.run.dir) / '*.t7'))

    if wandb.run.resumed:
        # Restore net/optim/scheduler state from the previous run.
        resume_project(net, optim, scheduler, config)
    else:
        wandb.run.summary['step'] = 0
        wandb.run.summary['epoch'] = 0
        wandb.run.summary['best_epoch'] = 0

    resume_epoch = max(wandb.run.summary['epoch'],
                       wandb.run.summary['best_epoch'])

    for epoch in tqdm.tqdm(range(resume_epoch + 1, config['max_epoch'] + 1),
                           desc='epoch',
                           position=0):
        wandb.run.summary['epoch'] = epoch

        checkpoint_project(net, optim, scheduler, config)

        loss_train = train_or_eval(teacher, net, data_train, optim, True,
                                   config)
        print(loss_train)

        # Validation pass: no optimizer, no gradients.
        with torch.no_grad():
            loss_val = train_or_eval(teacher, net, data_val, None, False,
                                     config)
        wandb.log({'train/loss_epoch': loss_train, 'val/loss_epoch': loss_val})

        # Track the best model by validation loss.
        if loss_val < wandb.run.summary.get('best_val_loss', np.inf):
            wandb.run.summary['best_val_loss'] = loss_val
            wandb.run.summary['best_epoch'] = epoch
            torch.save(net.state_dict(), Path(wandb.run.dir) / 'model_best.t7')

        # Periodic snapshot every 10 epochs.
        if epoch % 10 == 0:
            torch.save(net.state_dict(),
                       Path(wandb.run.dir) / ('model_%03d.t7' % epoch))
Ejemplo n.º 2
0
def main():
    """Search-phase training loop for the supernet on CIFAR-10/100."""
    if not torch.cuda.is_available():
        logging.info('No GPU device available')
        sys.exit(1)

    # Seed every RNG source and enable cudnn autotuning.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    logging.info("args = %s", args)
    logging.info("unparsed args = %s", unparsed)

    # prepare dataset (transforms and data come from the same branch)
    if args.is_cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(
            args)
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True,
                                   download=False, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir, train=False,
                                   download=False, transform=valid_transform)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True,
                                  download=False, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir, train=False,
                                  download=False, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True,
        pin_memory=True, num_workers=args.workers)
    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False,
        pin_memory=True, num_workers=args.workers)

    # build Network
    criterion = nn.CrossEntropyLoss().cuda()
    supernet = Network(args.init_channels, CIFAR_CLASSES, args.layers)
    supernet.cuda()

    # Heavier weight decay is used for CIFAR-100.
    weight_decay = 5e-4 if args.is_cifar100 else 3e-4
    optimizer = torch.optim.SGD(
        supernet.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=weight_decay,
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, float(args.epochs), eta_min=args.learning_rate_min)

    for epoch in range(args.epochs):
        logging.info('epoch %d lr %e', epoch, scheduler.get_last_lr()[0])

        train_acc, train_obj = train(train_queue, supernet, criterion,
                                     optimizer)
        logging.info('train_acc %f', train_acc)

        # Evaluate several randomly sampled subnets; report mean accuracy.
        valid_top1 = utils.AverageMeter()
        for _ in range(args.eval_time):
            supernet.generate_share_alphas()
            sampled_alphas = supernet.cells[0].ops_alphas
            subnet = supernet.get_sub_net(sampled_alphas)

            valid_acc, valid_obj = infer(valid_queue, subnet, criterion)
            valid_top1.update(valid_acc)
        logging.info('Mean Valid Acc: %f', valid_top1.avg)
        scheduler.step()

        utils.save(supernet, os.path.join(args.save, 'supernet_weights.pt'))
Ejemplo n.º 3
0
class DQNAgent:
    """Deep Q-Network agent for a Unity ML-Agents environment.

    Keeps an online network (``dqn``) and a target network (``dqn_target``);
    the target is hard-synced from the online weights every
    ``target_update`` gradient updates.
    """

    def __init__(
        self,
        env: UnityEnvironment,
        memory_size: int,
        batch_size: int,
        target_update: int,
        epsilon_decay: float = 1 / 2000,
        max_epsilon: float = 1.0,
        min_epsilon: float = 0.1,
        gamma: float = 0.99,
        ):
        """Bind the agent to *env* and build networks, buffer and optimizer.

        Args:
            env: running Unity environment; its first brain is used.
            memory_size: replay-buffer capacity.
            batch_size: minibatch size sampled from the buffer.
            target_update: gradient updates between target-network syncs.
            epsilon_decay: fraction of the epsilon range removed per episode.
            max_epsilon: starting exploration rate.
            min_epsilon: floor for the exploration rate.
            gamma: reward discount factor.
        """
        self.brain_name = env.brain_names[0]
        self.brain = env.brains[self.brain_name]
        env_info = env.reset(train_mode=True)[self.brain_name]
        self.env = env
        action_size = self.brain.vector_action_space_size
        state = env_info.vector_observations[0]
        state_size = len(state)

        self.obs_dim = state_size
        # NOTE(review): action_dim is hard-coded to 1 rather than
        # action_size, so the greedy argmax below can only ever pick
        # action 0 — confirm against the environment's action space.
        self.action_dim = 1

        self.memory = ReplayBuffer(self.obs_dim, self.action_dim, memory_size, batch_size)

        self.batch_size = batch_size
        self.target_update = target_update
        self.epsilon_decay = epsilon_decay
        self.max_epsilon = max_epsilon
        self.min_epsilon = min_epsilon
        self.gamma = gamma
        self.epsilon = max_epsilon

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Target network starts as an exact copy of the online network.
        self.dqn = Network(self.obs_dim, self.action_dim)
        self.dqn_target = Network(self.obs_dim, self.action_dim)
        self.dqn_target.load_state_dict(self.dqn.state_dict())
        self.dqn_target.eval()

        self.optimizer = optim.Adam(self.dqn.parameters(), lr=5e-5)

        # Transition under construction:
        # [state, action, reward, next_state, done].
        self.transition = list()

        self.is_test = False

    def select_action(self, state: np.ndarray) -> np.int64:
        """Select an epsilon-greedy action for *state*."""
        if self.epsilon > np.random.random():
            # np.random.random_integers is deprecated and removed in modern
            # NumPy; randint's half-open [0, action_dim) is equivalent to
            # the old inclusive [0, action_dim - 1] range.
            selected_action = np.random.randint(0, self.action_dim)
        else:
            selected_action = self.dqn(
                torch.FloatTensor(state).to(self.device)
            )
            selected_action = np.argmax(selected_action.detach().cpu().numpy())

        if not self.is_test:
            # Begin a new transition; step() appends reward/next_state/done.
            self.transition = [state, selected_action]

        return selected_action

    def step(self, action: np.int64) -> Tuple[np.ndarray, np.float64, bool]:
        """Take *action* in the environment and return (next_state, reward, done)."""
        env_info = self.env.step(action)[self.brain_name]
        next_state = env_info.vector_observations[0]
        reward = env_info.rewards[0]
        done = env_info.local_done[0]

        if not self.is_test:
            # Complete and store the transition begun in select_action().
            self.transition += [reward, next_state, done]
            self.memory.store(*self.transition)

        return next_state, reward, done

    def update_model(self) -> float:
        """Run one gradient step on a sampled batch and return the scalar loss.

        Note: returns ``loss.item()`` (a Python float), so the annotation is
        ``float`` rather than the original's misleading ``torch.Tensor``.
        """
        samples = self.memory.sample_batch()
        loss = self._compute_dqn_loss(samples)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    def train(self, num_episode: int, max_iteration: int=1000, plotting_interval: int=400):
        """Train the agent for *num_episode* episodes and save the weights.

        Args:
            num_episode: number of training episodes.
            max_iteration: per-episode step cap.
            plotting_interval: unused; kept for interface compatibility.
        """
        self.is_test = False

        env_info = self.env.reset(train_mode=True)[self.brain_name]
        state = env_info.vector_observations[0]

        update_cnt = 0
        epsilons = []
        losses = []
        avg_losses = []
        scores = []
        avg_scores = []

        for episode in range(num_episode):
            env_info = self.env.reset(train_mode=True)[self.brain_name]
            state = env_info.vector_observations[0]
            score = 0
            for _ in range(max_iteration):
                action = self.select_action(state)
                next_state, reward, done = self.step(action)
                state = next_state
                score += reward
                if done:
                    break

                # Learn once the buffer holds more than one batch.
                if len(self.memory) > self.batch_size:
                    loss = self.update_model()
                    losses.append(loss)
                    update_cnt += 1

            # np.mean([]) yields nan for episodes with no updates; the loss
            # plot simply shows a gap there.
            avg_losses.append(np.mean(losses))
            losses = []

            # Linear epsilon decay toward min_epsilon, recorded once per
            # episode (the original appended twice, doubling every point
            # on the epsilon plot).
            self.epsilon = max(
                self.min_epsilon, self.epsilon - (
                    self.max_epsilon - self.min_epsilon
                ) * self.epsilon_decay
            )
            epsilons.append(self.epsilon)

            # Hard-sync the target network every `target_update` updates.
            if update_cnt % self.target_update == 0:
                self._target_hard_update()
            scores.append(score)

            if episode >= 100:
                avg_scores.append(np.mean(scores[-100:]))
            self._plot(episode, scores, avg_scores, avg_losses, epsilons)
        torch.save(self.dqn.state_dict(), "model_weight/dqn.pt")

    def test(self):
        """Run one evaluation episode (no exploration bookkeeping) and print the score."""
        self.is_test = True
        env_info = self.env.reset(train_mode=False)[self.brain_name]
        state = env_info.vector_observations[0]
        done = False
        score = 0

        while not done:
            action = self.select_action(state)
            next_state, reward, done = self.step(action)

            state = next_state
            score += reward

        print("score: ", score)
        self.env.close()

    def _compute_dqn_loss(self, samples: Dict[str, np.ndarray], gamma: float=0.99) -> torch.Tensor:
        """Compute the smooth-L1 TD loss for a batch of transitions.

        Note: the *gamma* parameter is retained for interface compatibility
        but is immediately overridden by ``self.gamma``.
        """
        gamma = self.gamma
        device = self.device
        state = torch.FloatTensor(samples["obs"]).to(device)
        next_state = torch.FloatTensor(samples["next_obs"]).to(device)
        action = torch.LongTensor(samples["acts"]).reshape(-1, 1).to(device)
        reward = torch.FloatTensor(samples["rews"]).reshape(-1, 1).to(device)
        done = torch.FloatTensor(samples["done"]).reshape(-1, 1).to(device)

        curr_q_value = self.dqn(state).gather(1, action)

        # Bootstrapped target from the frozen target network; terminal
        # transitions are masked so no future value leaks past `done`.
        next_q_value = self.dqn_target(next_state).max(dim=1, keepdim=True)[0].detach()
        mask = 1 - done
        target = (reward + gamma * next_q_value * mask).to(device)
        loss = F.smooth_l1_loss(curr_q_value, target)

        return loss

    def _target_hard_update(self):
        """Copy the online network weights into the target network."""
        self.dqn_target.load_state_dict(self.dqn.state_dict())

    def _plot(
        self,
        episode: int,
        scores: List[float],
        avg_scores: List[float],
        losses: List[float],
        epsilons: List[float]
    ):
        """Save a 4-panel training-progress figure to plots/dqn_result.png."""
        plt.figure(figsize=(20, 5))
        plt.subplot(141)
        if len(avg_scores) > 0:
            plt.title("Average reward per 100 episodes. Score: %s" % (avg_scores[-1]))
        else:
            plt.title("Average reward over 100 episodes.")
        plt.plot([100 + i for i in range(len(avg_scores))], avg_scores)
        plt.subplot(142)
        plt.title("episode %s. Score: %s" % (episode, np.mean(scores[-10:])))
        plt.plot(scores)
        plt.subplot(143)
        plt.title('Loss')
        plt.plot(losses)
        plt.subplot(144)
        plt.title('epsilons')
        plt.plot(epsilons)
        plt.savefig('plots/dqn_result.png')
Ejemplo n.º 4
0
def main():
    """Evaluate a trained supernet checkpoint on CIFAR and OOD test sets.

    Samples one shared-alpha subnet, records its architecture to disk,
    reports clean accuracy and OOD-detection AUCs, and optionally
    fine-tunes before re-evaluating.
    """
    if not torch.cuda.is_available():
        print('No GPU device available')
        sys.exit(1)

    # Seed every RNG source and enable cudnn autotuning.
    np.random.seed(args.seed)
    cudnn.benchmark = True
    torch.manual_seed(args.seed)
    cudnn.enabled = True
    torch.cuda.manual_seed(args.seed)
    # Fix: print() does not interpolate '%s' the way logging does — the
    # original printed the format string and the namespace as a tuple.
    print("args = %s" % (args,))
    print("unparsed args = %s" % (unparsed,))

    # prepare dataset
    if args.cifar100:
        train_transform, valid_transform = utils._data_transforms_cifar100(args)
    else:
        train_transform, valid_transform = utils._data_transforms_cifar10(args)
    if args.cifar100:
        train_data = dset.CIFAR100(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR100(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)
    else:
        train_data = dset.CIFAR10(root=args.tmp_data_dir, train=True, download=True, transform=train_transform)
        valid_data = dset.CIFAR10(root=args.tmp_data_dir, train=False, download=True, transform=valid_transform)

    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=args.batch_size, shuffle=True, pin_memory=True, num_workers=args.workers)

    valid_queue = torch.utils.data.DataLoader(
        valid_data, batch_size=args.batch_size, shuffle=False, pin_memory=True, num_workers=args.workers)

    # Out-of-distribution evaluation loaders, keyed by dataset name.
    ood_queues = {}
    for k in ['svhn', 'lsun_resized', 'imnet_resized']:
        ood_path = os.path.join(args.ood_dir, k)
        dset_ = dset.ImageFolder(ood_path, valid_transform)
        loader = torch.utils.data.DataLoader(
            dset_, batch_size=args.batch_size, shuffle=False,
            pin_memory=True, num_workers=args.workers
        )
        ood_queues[k] = loader

    # build Network and restore the checkpoint
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    supernet = Network(
        args.init_channels, CIFAR_CLASSES, args.layers,
        combine_method=args.feat_comb, is_cosine=args.is_cosine,
    )
    supernet.cuda()
    ckpt = torch.load(args.load_at)
    print(args.load_at)
    supernet.load_state_dict(ckpt)

    # Sample one set of shared alphas and read it back from the first cell.
    supernet.generate_share_alphas()
    alphas = supernet.cells[0].ops_alphas
    print(alphas)

    # Persist the sampled architecture next to the evaluation outputs.
    # NOTE(review): assumes load_at looks like ./results/<run>/...,
    # so split('/')[2] picks out the run name — verify for other layouts.
    out_dir = './results/{}/eval_out/{}'.format(args.load_at.split('/')[2], args.seed)
    os.makedirs(out_dir, exist_ok=True)
    torch.save(alphas, os.path.join(out_dir, 'alphas.pt'))
    with open(os.path.join(out_dir, 'alphas.txt'), 'w') as f:
        # One row of 0/1 digits per alpha row.
        for row in alphas.cpu().detach().numpy():
            f.write(''.join('{:d}'.format(int(v)) for v in row))
            f.write('\n')

    # CIFAR-100 conventionally uses a heavier weight decay.
    if args.cifar100:
        weight_decay = 5e-4
    else:
        weight_decay = 3e-4
    optimizer = torch.optim.SGD(
        supernet.parameters(),
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=weight_decay,
    )

    # Metrics before any fine-tuning.
    valid_acc, _ = infer(valid_queue, supernet, criterion)
    print('valid_acc {:.2f}'.format(valid_acc))
    lg_aucs, sm_aucs, ent_aucs = ood_eval(valid_queue, ood_queues, supernet, criterion)
    with open(os.path.join(out_dir, 'before.txt'), 'w') as f:
        f.write('-'.join([str(valid_acc), str(lg_aucs), str(sm_aucs), str(ent_aucs)]))

    if args.fine_tune:
        for epoch in range(args.epochs):
            # Fine-tuning runs at the optimizer's fixed LR (no scheduler).
            print('epoch {} lr {:.4f}'.format(epoch, 0.001))

            train_acc, _ = train(train_queue, supernet, criterion, optimizer)
            print('train_acc {:.2f}'.format(train_acc))

            valid_acc, _ = infer(valid_queue, supernet, criterion)
            print('valid_acc {:.2f}'.format(valid_acc))

        # Metrics after fine-tuning.
        lg_aucs, sm_aucs, ent_aucs = ood_eval(valid_queue, ood_queues, supernet, criterion)
        with open(os.path.join(out_dir, 'after.txt'), 'w') as f:
            f.write('-'.join([str(valid_acc), str(lg_aucs), str(sm_aucs), str(ent_aucs)]))