예제 #1
0
class Agent():
    """Policy-gradient agent that samples actions from a categorical policy.

    The policy network is an ``MLP`` with a 4-dimensional state input and
    2 discrete actions. During training, the log-probability of every sampled
    action is stored in ``self.memory`` and ``update`` performs one gradient
    step that weights those log-probabilities by the discounted returns of
    the collected rewards.
    """

    def __init__(self, test=False):
        """Build the policy network, optimizer, rollout memory and logger.

        Args:
            test: When true, the weights are loaded from ``./pg_best.cpt``
                right after the network is created.
        """
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')

        self.model = MLP(state_dim=4, action_num=2,
                         hidden_dim=256).to(self.device)
        if test:
            self.load('./pg_best.cpt')
        # discount factor applied to future rewards
        self.gamma = 0.99
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=3e-3)
        # rewards, actions, states and log-probabilities of the rollout
        self.memory = Memory()
        self.tensorboard = TensorboardLogger('./')

    def save(self, save_path):
        """Write the policy network's ``state_dict`` to ``save_path``."""
        print('save model to', save_path)
        torch.save(self.model.state_dict(), save_path)

    def load(self, load_path):
        """Load the policy network's weights from ``load_path``.

        The checkpoint is mapped onto ``self.device`` so that weights saved
        on a GPU can still be restored on a CPU-only machine.
        """
        print('load model from', load_path)
        state_dict = torch.load(load_path, map_location=self.device)
        self.model.load_state_dict(state_dict)

    def act(self, x, test=False):
        """Sample an action for the observation ``x``.

        Args:
            x: NumPy array holding a single observation; a batch dimension
                is prepended before it is fed to the network.
            test: When false, the log-probability of the sampled action is
                appended to ``self.memory.logprobs`` for the next ``update``.
                When true, the model is switched to eval mode and nothing is
                recorded.

        Returns:
            The sampled action as a Python int.
        """
        state = torch.from_numpy(x).unsqueeze(0).float().to(self.device)

        if test:
            self.model.eval()
            with torch.no_grad():
                # Evaluation still samples from the policy; it does not take
                # the argmax action.
                dist = torch.distributions.Categorical(self.model(state))
                return dist.sample().item()

        dist = torch.distributions.Categorical(self.model(state))
        action = dist.sample()
        # Keep the graph-attached log-probability for the policy-gradient step.
        self.memory.logprobs.append(dist.log_prob(action))
        return action.item()

    def collect_data(self, state, action, reward):
        """Append one transition (state, action, reward) to the memory."""
        self.memory.actions.append(action)
        self.memory.rewards.append(torch.tensor(reward))
        self.memory.states.append(state)

    def clear_data(self):
        """Drop everything stored in the rollout memory."""
        self.memory.clear_memory()

    def update(self):
        """Run one policy-gradient step on the transitions held in memory.

        Discounted returns are computed backwards over the stored rewards and
        standardized to zero mean / unit variance when there is more than one
        of them. A single return is used as-is, because the standard
        deviation of one sample is NaN and would poison every weight. The
        loss is the sum of ``-log_prob * return`` and is logged under the
        ``"loss"`` tag.

        Does nothing when no rewards or no log-probabilities were collected.
        The memory is not cleared here; call ``clear_data`` for that.
        """
        rewards = self.memory.rewards
        log_probs = self.memory.logprobs
        if not rewards or not log_probs:
            return

        # Accumulate back-to-front, then flip once (avoids O(n^2) inserts).
        returns = []
        running = 0.0
        for reward in reversed(rewards):
            running = running * self.gamma + float(reward)
            returns.append(running)
        returns.reverse()

        returns = torch.tensor(returns, dtype=torch.float32,
                               device=self.device)
        if returns.numel() > 1:
            eps = float(np.finfo(np.float32).eps)
            returns = (returns - returns.mean()) / (returns.std() + eps)

        # zip stops at the shorter sequence, as the pairing always did.
        policy_loss = [-log_prob * ret
                       for log_prob, ret in zip(log_probs, returns)]

        self.optimizer.zero_grad()
        loss = torch.cat(policy_loss).sum()
        loss.backward()
        self.optimizer.step()

        self.tensorboard.scalar_summary("loss", loss.item())
        self.tensorboard.update()
예제 #2
0
def train(opts):
    """Train a ``CAE`` model, validating and checkpointing after every epoch.

    Options read from ``opts``: ``arch``, ``mode``, ``data_dir``, ``bsize``,
    ``nworkers``, ``sigma``, ``alpha``, ``n_prototypes``, ``decoder_arch``,
    ``optim``, ``lr``, ``wd``, ``save_path``, ``resume``, ``start_epoch``,
    ``epochs``, ``max_norm``, ``log_iter`` and ``log_dir``.

    When ``opts.resume`` is set and ``model_latest.net`` exists under
    ``opts.save_path``, the model weights, optimizer state, iteration
    counter and best validation accuracy are restored from it, and ``opts``
    is replaced by the options stored in that checkpoint.

    Every epoch writes ``model_latest.net``; an epoch whose validation
    accuracy is at least the best seen so far also writes ``model_best.net``.

    Raises:
        NotImplementedError: For an unknown ``opts.arch``, ``opts.mode`` or
            ``opts.optim``.
    """
    # NOTE(review): `use_cuda` is not defined in this function; it is
    # expected to come from the enclosing module scope.
    device = torch.device("cuda" if use_cuda else "cpu")

    if opts.arch == 'small':
        channels = [32, 32, 32, 10]
    elif opts.arch == 'large':
        channels = [256, 128, 64, 32]
    else:
        raise NotImplementedError('Unknown model architecture')

    # The dataset modes differ only in loader factory, input channels and
    # image size.
    if opts.mode == 'train_mnist':
        get_loaders, in_channels, image_size = get_mnist_loaders, 1, 28
    elif opts.mode == 'train_cifar':
        get_loaders, in_channels, image_size = get_cifar_loaders, 3, 32
    elif opts.mode == 'train_fmnist':
        get_loaders, in_channels, image_size = get_fmnist_loaders, 1, 28
    else:
        raise NotImplementedError('Unknown train mode')

    train_loader, valid_loader = get_loaders(opts.data_dir, opts.bsize,
                                             opts.nworkers, opts.sigma,
                                             opts.alpha)
    model = CAE(in_channels, 10, image_size, opts.n_prototypes,
                opts.decoder_arch, channels)

    # Move the model before any optimizer state is restored, so the restored
    # state is placed on the same device as the parameters.
    model = model.to(device)

    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(),
                                     lr=opts.lr,
                                     weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")
    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # for choosing the best model
    best_val_acc = 0.0

    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # restoring training from save_state
        print('====> Resuming training from previous checkpoint')
        save_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(save_state['state_dict'])
        # Checkpoints written below always carry the optimizer state; restore
        # it so the optimizer does not restart from scratch.
        if 'optimizer' in save_state:
            optimizer.load_state_dict(save_state['optimizer'])
        start_n_iter = save_state['n_iter']
        best_val_acc = save_state['best_val_acc']
        opts = save_state['opts']
        opts.start_epoch = save_state['epoch'] + 1

    # for logging
    logger = TensorboardLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss', 'loss_class', 'loss_ae', 'loss_r1', 'loss_r2'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()
        # Ten random validation images, used for the decoded-pairs grid.
        sample_indices = random.sample(range(len(valid_loader.dataset)), 10)
        valid_sample = torch.stack(
            [valid_loader.dataset[i][0] for i in sample_indices]).to(device)

        for batch_idx, (data, target) in enumerate(train_loader):
            acc, loss, class_error, ae_error, error_1, error_2 = run_iter(
                opts, data, target, model, criterion, device)

            # optimizer step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()

            logger.update(acc, loss, class_error, ae_error, error_1, error_2)

        (val_loss, val_acc, val_class_error, val_ae_error, val_error_1,
         val_error_2, time_taken) = evaluate(opts, model, valid_loader,
                                             criterion, device)
        # log the validation losses
        logger.log_valid(time_taken, val_acc, val_loss, val_class_error,
                         val_ae_error, val_error_1, val_error_2)
        print('')

        is_best = val_acc >= best_val_acc
        if is_best:
            best_val_acc = val_acc

        save_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'n_iter': logger.n_iter,
            'opts': opts,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }

        # Save the model to disk: "best" first (when applicable), then
        # "latest", each with its prototype image.
        tags = ['best', 'latest'] if is_best else ['latest']
        for tag in tags:
            model_path = os.path.join(opts.save_path, 'model_%s.net' % tag)
            torch.save(save_state, model_path)
            prototypes = model.save_prototypes(opts.save_path,
                                               'prototypes_%s.png' % tag)
            grid = torchvision.utils.make_grid(prototypes, nrow=10,
                                               pad_value=1.0)
            logger.writer.add_image('Prototypes (%s)' % tag, grid, epoch)

        ae_samples = model.get_decoded_pairs_grid(valid_sample)
        logger.writer.add_image('AE_samples_latest', ae_samples, epoch)