Example #1
def main():
    parser = argparse.ArgumentParser(description="AAE")
    parser.add_argument("--num_epochs", type=int, default=100)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--data_root", type=str, default="./data")
    parser.add_argument("--data_name", type=str, default="mnist")
    parser.add_argument("--distribution", type=str, default="gaussian")
    parser.add_argument("--image_size", type=int, default=32)
    parser.add_argument("--image_channels", type=int, default=1)
    parser.add_argument("--latent_dim", type=int, default=2)
    parser.add_argument("--num_classes", type=int, default=10)
    opt = parser.parse_args()

    os.makedirs("./outputs/encode", exist_ok=True)
    os.makedirs("./outputs/decode", exist_ok=True)
    os.makedirs("./weights", exist_ok=True)


    encoder = Encoder(opt.image_size, opt.image_channels, opt.latent_dim).to(opt.device)
    decoder = Decoder(opt.image_size, opt.image_channels, opt.latent_dim).to(opt.device)
    discriminator = Discriminator(opt.latent_dim, opt.num_classes, True).to(opt.device)

    for epoch in range(opt.num_epochs):
        reconstruct_loss, e_loss, d_loss = train(encoder, decoder, discriminator, opt)
        print("reconstruct loss: {:.4f} encorder loss: {:.4f} discriminator loss: {:.4f}".format(reconstruct_loss, e_loss, d_loss))
        eval_encoder("./outputs/encode/{}.jpg".format(epoch), encoder, opt)
        eval_decoder("./outputs/decode/{}.jpg".format(epoch), decoder, opt)

    torch.save(encoder.state_dict(), "./weights/encoder.pth")
    torch.save(decoder.state_dict(), "./weights/decoder.pth")
    torch.save(discriminator.state_dict(), "./weights/discriminator.pth")
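Example #1 above calls eval_encoder and eval_decoder, which are not shown. A hypothetical sketch of a decoder-side check for the 2-D Gaussian latent used here (every name and shape below is an assumption, not the repo's code):

import torch
from torchvision.utils import save_image

def eval_decoder_sketch(path, decoder, opt, n=20):
    # Hypothetical: decode an n x n grid of points spanning the 2-D Gaussian prior
    # and tile the outputs into a single image (assumes opt.latent_dim == 2).
    decoder.eval()
    with torch.no_grad():
        lin = torch.linspace(-3.0, 3.0, n)
        z = torch.cartesian_prod(lin, lin).to(opt.device)  # (n*n, 2) latent grid
        images = decoder(z)                                # assumed (n*n, C, H, W)
        save_image(images, path, nrow=n)
    decoder.train()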
Example #2
def load():

    encoder_net = Encoder(vocab_enc, 150, 200, 1, 0.3).to("cpu")
    decoder_net = Decoder(
        vocab_dec,
        150,
        200,
        vocab_dec,
        1,
        0.3,
    ).to("cpu")
    encoder_net.load_state_dict(
        torch.load("/home/aradhya/Desktop/hacks/model_for_faq_encoder.pt"))
    decoder_net.load_state_dict(
        torch.load("/home/aradhya/Desktop/hacks/model_for_faq_decoder.pt"))

    return encoder_net, decoder_net
Example #3
def main():
    parser = argparse.ArgumentParser(description='Implementation of SimCLR')
    parser.add_argument('--EPOCHS',
                        default=10,
                        type=int,
                        help='Number of epochs for training')
    parser.add_argument('--BATCH_SIZE',
                        default=64,
                        type=int,
                        help='Batch size')
    parser.add_argument('--TEMP',
                        default=0.5,
                        type=float,
                        help='Temperature parameter for NT-Xent')
    parser.add_argument(
        '--LOG_INT',
        default=100,
        type=int,
        help='How many batches to wait before logging training status')
    parser.add_argument('--DISTORT_STRENGTH',
                        default=0.5,
                        type=float,
                        help='Strength of colour distortion')
    parser.add_argument('--SAVE_NAME', default='model')
    args = parser.parse_args()
    use_cuda = torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    online_transform = transforms.Compose([
        transforms.RandomResizedCrop((32, 32)),
        transforms.RandomHorizontalFlip(),
        get_color_distortion(s=args.DISTORT_STRENGTH),
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    trainset = CIFAR10_new(root='./data',
                           train=True,
                           download=True,
                           transform=online_transform)

    # Need to drop the last minibatch to prevent matrix multiplication errors
    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.BATCH_SIZE,
                                               shuffle=True,
                                               drop_last=True)

    model = Encoder().to(device)
    optimizer = optim.Adam(model.parameters())
    loss_func = losses.NTXentLoss(args.TEMP)
    for epoch in range(args.EPOCHS):
        train(args, model, device, train_loader, optimizer, loss_func, epoch)

    torch.save(model.state_dict(), './ckpt/{}.pth'.format(args.SAVE_NAME))
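Example #3 relies on a get_color_distortion helper that is not shown. A minimal sketch following the colour-distortion recipe from the SimCLR paper (an assumption about this repo's implementation, not its actual code):

from torchvision import transforms

def get_color_distortion(s=1.0):
    # s is the distortion strength; colour jitter is applied with p=0.8, grayscale with p=0.2
    color_jitter = transforms.ColorJitter(0.8 * s, 0.8 * s, 0.8 * s, 0.2 * s)
    rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8)
    rnd_gray = transforms.RandomGrayscale(p=0.2)
    return transforms.Compose([rnd_color_jitter, rnd_gray])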
Example #4
def main(args):
    # load datasets
    vocab = pickle.load(open(args.vocab_path, 'rb'))
    train_data = get_loader(args.json_file,
                            args.mat_file,
                            vocab,
                            args.batch_size,
                            shuffle=True,
                            num_workers=args.num_workers)
    # get vocab
    # build model
    if args.encoder:
        encoder = Encoder(args.embed_size, True)

    decoder = Decoder(args.in_features, len(vocab),
                      args.embed_size, args.hidden_size)
    # define loss and optimizer
    criterion = nn.CrossEntropyLoss()
    params = list(decoder.parameters())
    if args.encoder:
        params = list(decoder.parameters()) + list(encoder.cnn.fc.parameters())
    optimizer = optim.Adam(params, lr=args.learning_rate)
    # train
    total_step = len(train_data)
    for epoch in range(args.num_epochs):
        for i, (images, captions, lengths) in enumerate(train_data):
            if args.encoder:
                images_features = encoder(Variable(images))
            else:
                images_features = Variable(images)
            captions = Variable(captions)
            targets = pack_padded_sequence(
                captions, lengths, batch_first=True)[0]

            decoder.zero_grad()
            outputs = decoder(images_features, captions, lengths)
            loss = criterion(outputs, targets)
            loss.backward()
            optimizer.step()

            if i % args.disp_step == 0:
                print('Epoch [%d/%d], step [%d/%d], loss: %.4f, Perplexity: %5.4f'
                      % (epoch, args.num_epochs, i, total_step, loss.data[0], np.exp(loss.data[0])))

            # Save the models
            if (i + 1) % args.save_step == 0:
                torch.save(decoder.state_dict(),
                           os.path.join(args.model_path,
                                        'decoder-%d-%d.pkl' % (epoch + 1, i + 1)))
                if args.encoder:
                    torch.save(encoder.state_dict(),
                               os.path.join(args.model_path,
                                            'encoder-%d-%d.pkl' % (epoch + 1, i + 1)))
Example #5
        evaluate(model,
                 vgg,
                 dataloader_val,
                 criterion,
                 old_epoch + epoch,
                 gram_ix=args.gram_ix,
                 split='Val')

        if (epoch + 1) % 10 == 0:
            model_name = model_name.split('-e=')[0] + '-e={}'.format(
                old_epoch + epoch)
            PATH = f'./models/{model_name}.pt'
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                }, PATH)
            # torch.save(model.state_dict(), PATH[:-3] + '-eval.pt')

    model_name = model_name.split('-e=')[0] + '-e={}'.format(old_epoch +
                                                             args.epochs)
    PATH = f'./models/{model_name}.pt'
    torch.save(
        {
            'epoch': args.epochs,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
        }, PATH)
    # torch.save(model.state_dict(), PATH[:-3] + '-eval.pt')
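Example #5 saves a checkpoint dictionary keyed by 'epoch', 'model_state_dict' and 'optimizer_state_dict'. A short companion sketch of the matching restore step (an assumption reusing the example's model, optimizer and PATH names, not code from the repo):

import torch

checkpoint = torch.load(PATH, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
old_epoch = checkpoint['epoch']
model.train()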
Example #6
class Trainer():
    def __init__(self, params, experience_replay_buffer,metrics,results_dir,env):
        self.parms = params     
        self.D = experience_replay_buffer  
        self.metrics = metrics
        self.env = env
        self.tested_episodes = 0

        self.statistics_path = results_dir+'/statistics' 
        self.model_path = results_dir+'/model' 
        self.video_path = results_dir+'/video' 
        self.rew_vs_pred_rew_path = results_dir+'/rew_vs_pred_rew'
        self.dump_plan_path = results_dir+'/dump_plan'
        
        # if the folders do not exist, create them
        os.makedirs(self.statistics_path, exist_ok=True) 
        os.makedirs(self.model_path, exist_ok=True) 
        os.makedirs(self.video_path, exist_ok=True) 
        os.makedirs(self.rew_vs_pred_rew_path, exist_ok=True) 
        os.makedirs(self.dump_plan_path, exist_ok=True) 
        

        # Create models
        self.transition_model = TransitionModel(self.parms.belief_size, self.parms.state_size, self.env.action_size, self.parms.hidden_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        self.observation_model = ObservationModel(self.parms.belief_size, self.parms.state_size, self.parms.embedding_size, self.parms.activation_function).to(device=self.parms.device)
        self.reward_model = RewardModel(self.parms.belief_size, self.parms.state_size, self.parms.hidden_size, self.parms.activation_function).to(device=self.parms.device)
        self.encoder = Encoder(self.parms.embedding_size,self.parms.activation_function).to(device=self.parms.device)
        self.param_list = list(self.transition_model.parameters()) + list(self.observation_model.parameters()) + list(self.reward_model.parameters()) + list(self.encoder.parameters()) 
        self.optimiser = optim.Adam(self.param_list, lr=0 if self.parms.learning_rate_schedule != 0 else self.parms.learning_rate, eps=self.parms.adam_epsilon)
        self.planner = MPCPlanner(self.env.action_size, self.parms.planning_horizon, self.parms.optimisation_iters, self.parms.candidates, self.parms.top_candidates, self.transition_model, self.reward_model,self.env.action_range[0], self.env.action_range[1])

        global_prior = Normal(torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device), torch.ones(self.parms.batch_size, self.parms.state_size, device=self.parms.device))  # Global prior N(0, I)
        self.free_nats = torch.full((1, ), self.parms.free_nats, dtype=torch.float32, device=self.parms.device)  # Allowed deviation in KL divergence

    def load_checkpoints(self):
        self.metrics = torch.load(self.model_path+'/metrics.pth')
        model_path = self.model_path+'/best_model'
        os.makedirs(model_path, exist_ok=True) 
        files = os.listdir(model_path)
        if files:
            checkpoint = [f for f in files if os.path.isfile(os.path.join(model_path, f))]
            model_dicts = torch.load(os.path.join(model_path, checkpoint[0]),map_location=self.parms.device)
            self.transition_model.load_state_dict(model_dicts['transition_model'])
            self.observation_model.load_state_dict(model_dicts['observation_model'])
            self.reward_model.load_state_dict(model_dicts['reward_model'])
            self.encoder.load_state_dict(model_dicts['encoder'])
            self.optimiser.load_state_dict(model_dicts['optimiser'])  
            print("Loading models checkpoints!")
        else:
            print("Checkpoints not found!")


    def update_belief_and_act(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf,explore=False):
        # Infer belief over current state q(s_t|o≤t,a<t) from the history
        encoded_obs = self.encoder(observation).unsqueeze(dim=0).to(device=self.parms.device)       
        belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs)  # Action and observation need extra time dimension
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
        action,pred_next_rew,_,_,_ = self.planner(belief, posterior_state,explore)  # Get action from planner(q(s_t|o≤t,a<t), p)      
        
        if explore:
            action = action + self.parms.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
        action.clamp_(min=min_action, max=max_action)  # Clip action range
        next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # If a single env is instantiated, perform a single action (take the item from the batch); otherwise perform all actions
        
        return belief, posterior_state, action, next_observation, reward, done,pred_next_rew 
    
    def fit_buffer(self,episode):
        ####
        # Fit data taken from buffer 
        ######

        # Model fitting
        losses = []
        tqdm.write("Fitting buffer")
        for s in tqdm(range(self.parms.collect_interval)):

            # Draw sequence chunks {(o_t, a_t, r_t+1, terminal_t+1)} ~ D uniformly at random from the dataset (including terminal flags)
            observations, actions, rewards, nonterminals = self.D.sample(self.parms.batch_size, self.parms.chunk_size)  # Transitions start at time t = 0
            # Create initial belief and state for time t = 0
            init_belief, init_state = torch.zeros(self.parms.batch_size, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.batch_size, self.parms.state_size, device=self.parms.device)
            encoded_obs = bottle(self.encoder, (observations[1:], ))

            # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(init_state, actions[:-1], init_belief, encoded_obs, nonterminals[:-1])
            
            # Calculate observation likelihood, reward likelihood and KL losses (for t = 0 only for latent overshooting); sum over final dims, average over batch and time (original implementation, though paper seems to miss 1/T scaling?)
            # LOSS
            observation_loss = F.mse_loss(bottle(self.observation_model, (beliefs, posterior_states)), observations[1:], reduction='none').sum((2, 3, 4)).mean(dim=(0, 1))
            kl_loss = torch.max(kl_divergence(Normal(posterior_means, posterior_std_devs), Normal(prior_means, prior_std_devs)).sum(dim=2), self.free_nats).mean(dim=(0, 1))  
            reward_loss = F.mse_loss(bottle(self.reward_model, (beliefs, posterior_states)), rewards[:-1], reduction='none').mean(dim=(0, 1))            

            # Update model parameters
            self.optimiser.zero_grad()

            (observation_loss + reward_loss + kl_loss).backward() # BACKPROPAGATION
            nn.utils.clip_grad_norm_(self.param_list, self.parms.grad_clip_norm, norm_type=2)
            self.optimiser.step()
            # Store (0) observation loss (1) reward loss (2) KL loss
            losses.append([observation_loss.item(), reward_loss.item(), kl_loss.item()])#, regularizer_loss.item()])

        #save statistics and plot them
        losses = tuple(zip(*losses))  
        self.metrics['observation_loss'].append(losses[0])
        self.metrics['reward_loss'].append(losses[1])
        self.metrics['kl_loss'].append(losses[2])
      
        lineplot(self.metrics['episodes'][-len(self.metrics['observation_loss']):], self.metrics['observation_loss'], 'observation_loss', self.statistics_path)
        lineplot(self.metrics['episodes'][-len(self.metrics['reward_loss']):], self.metrics['reward_loss'], 'reward_loss', self.statistics_path)
        lineplot(self.metrics['episodes'][-len(self.metrics['kl_loss']):], self.metrics['kl_loss'], 'kl_loss', self.statistics_path)
        
    def explore_and_collect(self,episode):
        tqdm.write("Collect new data:")
        reward = 0
        # Data collection
        with torch.no_grad():
            done = False
            observation, total_reward = self.env.reset(), 0
            belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device)
            t = 0
            real_rew = []
            predicted_rew = [] 
            total_steps = self.parms.max_episode_length // self.env.action_repeat
            explore = True

            for t in tqdm(range(total_steps)):
                # Here we need to explore
                belief, posterior_state, action, next_observation, reward, done, pred_next_rew = self.update_belief_and_act(self.env, belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1], explore=explore)
                self.D.append(observation, action.cpu(), reward, done)
                real_rew.append(reward)
                predicted_rew.append(pred_next_rew.to(device=self.parms.device).item())
                total_reward += reward
                observation = next_observation
                if self.parms.flag_render:
                    self.env.render()
                if done:
                    break

        # Update and plot train reward metrics
        self.metrics['steps'].append( (t * self.env.action_repeat) + self.metrics['steps'][-1])
        self.metrics['episodes'].append(episode)
        self.metrics['train_rewards'].append(total_reward)
        self.metrics['predicted_rewards'].append(np.array(predicted_rew).sum())

        lineplot(self.metrics['episodes'][-len(self.metrics['train_rewards']):], self.metrics['train_rewards'], 'train_rewards', self.statistics_path)
        double_lineplot(self.metrics['episodes'], self.metrics['train_rewards'], self.metrics['predicted_rewards'], "train_r_vs_pr", self.statistics_path)

    def train_models(self):
        # from (init_episodes) to (training_episodes + init_episodes)
        tqdm.write("Start training.")

        for episode in tqdm(range(self.parms.num_init_episodes +1, self.parms.training_episodes) ):
            self.fit_buffer(episode)       
            self.explore_and_collect(episode)
            if episode % self.parms.test_interval == 0:
                self.test_model(episode)
                torch.save(self.metrics, os.path.join(self.model_path, 'metrics.pth'))
                torch.save({'transition_model': self.transition_model.state_dict(), 'observation_model': self.observation_model.state_dict(), 'reward_model': self.reward_model.state_dict(), 'encoder': self.encoder.state_dict(), 'optimiser': self.optimiser.state_dict()},  os.path.join(self.model_path, 'models_%d.pth' % episode))
            
            if episode % self.parms.storing_dataset_interval == 0:
                self.D.store_dataset(self.parms.dataset_path+'dump_dataset')

        return self.metrics

    def test_model(self, episode=None): #no explore here
        if episode is None:
            episode = self.tested_episodes


        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        
        # Initialise parallelised test environments
        test_envs = EnvBatcher(ControlSuiteEnv, (self.parms.env_name, self.parms.seed, self.parms.max_episode_length, self.parms.bit_depth), {}, self.parms.test_episodes)
        total_steps = self.parms.max_episode_length // test_envs.action_repeat
        rewards = np.zeros(self.parms.test_episodes)
        
        real_rew = torch.zeros([total_steps,self.parms.test_episodes])
        predicted_rew = torch.zeros([total_steps,self.parms.test_episodes])

        with torch.no_grad():
            observation, total_rewards, video_frames = test_envs.reset(), np.zeros((self.parms.test_episodes, )), []            
            belief, posterior_state, action = torch.zeros(self.parms.test_episodes, self.parms.belief_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.parms.state_size, device=self.parms.device), torch.zeros(self.parms.test_episodes, self.env.action_size, device=self.parms.device)
            tqdm.write("Testing model.")
            for t in range(total_steps):     
                belief, posterior_state, action, next_observation, rewards, done, pred_next_rew  = self.update_belief_and_act(test_envs,  belief, posterior_state, action, observation.to(device=self.parms.device), list(rewards), self.env.action_range[0], self.env.action_range[1])
                total_rewards += rewards.numpy()
                real_rew[t] = rewards
                predicted_rew[t]  = pred_next_rew

                observation = self.env.get_original_frame().unsqueeze(dim=0)

                video_frames.append(make_grid(torch.cat([observation, self.observation_model(belief, posterior_state).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                observation = next_observation
                if done.sum().item() == self.parms.test_episodes:
                    break
            
        real_rew = torch.transpose(real_rew, 0, 1)
        predicted_rew = torch.transpose(predicted_rew, 0, 1)
        
        #save and plot metrics 
        self.tested_episodes += 1
        self.metrics['test_episodes'].append(episode)
        self.metrics['test_rewards'].append(total_rewards.tolist())

        lineplot(self.metrics['test_episodes'], self.metrics['test_rewards'], 'test_rewards', self.statistics_path)
        
        write_video(video_frames, 'test_episode_%s' % str(episode), self.video_path)  # Lossy compression
        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        test_envs.close()
        return self.metrics


    def dump_plan_video(self, step_before_plan=120): 
        # number of steps to execute (and record) before dumping the plan
        step_before_plan = min(step_before_plan, (self.parms.max_episode_length // self.env.action_repeat))
        
        # Set models to eval mode
        self.transition_model.eval()
        self.observation_model.eval()
        self.reward_model.eval()
        self.encoder.eval()
        video_frames = []
        reward = 0

        with torch.no_grad():
            observation = self.env.reset()
            belief, posterior_state, action = torch.zeros(1, self.parms.belief_size, device=self.parms.device), torch.zeros(1, self.parms.state_size, device=self.parms.device), torch.zeros(1, self.env.action_size, device=self.parms.device)
            tqdm.write("Executing episode.")
            for t in range(step_before_plan): #floor division
                belief, posterior_state, action, next_observation, reward, done, _ = self.update_belief_and_act(self.env,  belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1])
                observation = next_observation
                video_frames.append(make_grid(torch.cat([observation.cpu(), self.observation_model(belief, posterior_state).to(device=self.parms.device).cpu()], dim=3) + 0.5, nrow=5).numpy())  # Decentre
                if done:
                    break
            self.create_and_dump_plan(self.env,  belief, posterior_state, action, observation.to(device=self.parms.device), [reward], self.env.action_range[0], self.env.action_range[1])
            
            
        # Set models to train mode
        self.transition_model.train()
        self.observation_model.train()
        self.reward_model.train()
        self.encoder.train()
        # Close test environments
        self.env.close()

    def create_and_dump_plan(self, env, belief, posterior_state, action, observation, reward, min_action=-inf, max_action=inf): 

        tqdm.write("Dumping plan")
        video_frames = []

        encoded_obs = self.encoder(observation).unsqueeze(dim=0)
        belief, _, _, _, posterior_state, _, _ = self.transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoded_obs)  
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
        next_action,_, beliefs, states, plan = self.planner(belief, posterior_state,False)  # Get action from planner(q(s_t|o≤t,a<t), p)      
        predicted_frames = self.observation_model(beliefs, states).to(device=self.parms.device)

        for i in range(self.parms.planning_horizon):
            plan[i].clamp_(min=env.action_range[0], max=self.env.action_range[1])  # Clip action range
            next_observation, reward, done = env.step(plan[i].cpu())  
            next_observation = next_observation.squeeze(dim=0)
            video_frames.append(make_grid(torch.cat([next_observation, predicted_frames[i]], dim=1) + 0.5, nrow=2).numpy())  # Decentre

        write_video(video_frames, 'dump_plan', self.dump_plan_path, dump_frame=True)  
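The KL term in fit_buffer above clamps each KL value with the "free nats" threshold before averaging. A standalone minimal sketch of just that piece, with illustrative tensor shapes and values (both are assumptions):

import torch
from torch.distributions import Normal, kl_divergence

free_nats = torch.full((1,), 3.0)                      # allowed deviation in nats
posterior = Normal(torch.zeros(4, 8), 0.9 * torch.ones(4, 8))
prior = Normal(torch.zeros(4, 8), torch.ones(4, 8))
kl = kl_divergence(posterior, prior).sum(dim=1)        # per-sample KL, shape (4,)
kl_loss = torch.max(kl, free_nats).mean()              # KL below the threshold adds no gradient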
    
            
Example #7
      write_video(video_frames, 'test_episode_%s' % episode_str, results_dir)  # Lossy compression
      for i in range(len(video_frames)):
        save_image(torch.as_tensor(video_frames[i]), os.path.join(results_dir, latenSubName, 'test_episode_%s_%s.png' % (episode_str, str(i))))
      save_image(torch.as_tensor(video_frames[-1]), os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
    torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

    # Set models to train mode
    transition_model.train()
    observation_model.train()
    reward_model.train()
    encoder.train()
    # Close test environments
    test_envs.close()


  # Checkpoint models
  if episode % args.checkpoint_interval == 0:
    torch.save({'transition_model': transition_model.state_dict(), 'observation_model': observation_model.state_dict(), 'reward_model': reward_model.state_dict(), 'encoder': encoder.state_dict(), 'optimiser': optimiser.state_dict()}, os.path.join(results_dir, 'models_%d.pth' % episode))
    if args.checkpoint_experience:
      torch.save(D, os.path.join(results_dir, 'experience.pth'))  # Warning: will fail with MemoryError with large memory sizes


# Close training environment
env.close()
print('The total time taken is:')
print(datetime.now() - start)

# Pickle Dump Metrics
with open(os.path.join(results_dir, 'metrics.pkl'), 'wb') as f:
    pickle.dump(mylist, f)
Example #8
epochs = 1000
data_dir = '/home/lucliu/dataset/domain_adaptation/office31'
src_dir = 'amazon'
#src_dir = 'webcam'
cuda = torch.cuda.is_available() 
# dataloader
src_train_loader = get_dataloader(data_dir, src_dir, batch_size, train=True)
# model
# Pretrained Model
alexnet = torchvision.models.alexnet(pretrained=True)
pretrained_dict = alexnet.state_dict()
# Train source data
# Model parameters
src_encoder = Encoder()
src_classifier = ClassClassifier(num_classes=31)
src_encoder_dict = src_encoder.state_dict()
# Load pretrained model 
# filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in src_encoder_dict}
# overwrite entries in the existing state dict
src_encoder_dict.update(pretrained_dict) 
# load the new state dict
src_encoder.load_state_dict(src_encoder_dict)
optimizer = optim.SGD(
    list(src_encoder.parameters()) + list(src_classifier.parameters()),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay)

criterion = nn.CrossEntropyLoss()
Example #9
                if batch_idx % log_interval == 0:

                    print('Eval Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f} {}'.
                          format(epoch, batch_idx * len(data),
                                 len(train_loader.dataset),
                                 100. * batch_idx / len(train_loader),
                                 loss.item() / len(data),
                                 datetime.datetime.now()))

                    _, preds = torch.max(scores_copy, dim=2)
                    preds = preds.cpu().numpy()
                    targets_copy = targets_copy.cpu().numpy()
                    for i in range(4):
                        sample = preds[i, ...]
                        target = targets_copy[i, ...]
                        print("ORIG: {}\nNEW : {}\n".format(
                            "".join([charset[chars] for chars in target]),
                            "".join([charset[chars] for chars in sample])))

        experiment.log_metric("loss", losses.avg)
    return losses.avg


for epoch in range(starting_epoch, epochs):
    decoder_sched.step()
    encoder_sched.step()
    train(epoch)
    val = test(epoch)
    torch.save(encoder.state_dict(), "encoder." + str(epoch) + ".pt")
    torch.save(decoder.state_dict(), "decoder." + str(epoch) + ".pt")
Example #10
                os.path.join(results_dir, 'test_episode_%s.png' % episode_str))
        torch.save(metrics, os.path.join(results_dir, 'metrics.pth'))

        # Set models to train mode
        transition_model.train()
        observation_model.train()
        reward_model.train()
        encoder.train()
        # Close test environments
        test_envs.close()

    # Checkpoint models
    print("Completed episode {}".format(episode))
    if episode % args.checkpoint_interval == 0:
        print("Saving!")
        torch.save(
            {
                'transition_model': transition_model.state_dict(),
                'observation_model': observation_model.state_dict(),
                'reward_model': reward_model.state_dict(),
                'encoder': encoder.state_dict(),
                'optimiser': optimiser.state_dict()
            }, os.path.join(results_dir, 'models_%d.pth' % episode))
        if args.checkpoint_experience:
            torch.save(
                D, os.path.join(results_dir, 'experience.pth')
            )  # Warning: will fail with MemoryError with large memory sizes

# Close training environment
env.close()
Example #11
        optimizer_D_B.step()

        loss_GAN_iter += loss_GAN_A2B.data + loss_GAN_B2A.data
        loss_cycle_iter += loss_cycle_ABA.data + loss_cycle_BAB.data
        loss_D_iter += loss_D_A.data + loss_D_B.data
        print(
            'epoch [%d/%d], iteration [%d/%d], loss_G: %.3f, loss_G_GAN: %.3f, loss_G_cycle: %.3f, loss_D: %.3f'
            % (epoch + 1, n_epochs, i + 1, len(dataloader), loss_G,
               loss_GAN_A2B + loss_GAN_B2A, loss_cycle_ABA + loss_cycle_BAB,
               loss_D_A + loss_D_B))

    # Update learning rates
    lr_scheduler_G.step()
    lr_scheduler_D_A.step()
    lr_scheduler_D_B.step()

    loss_GAN_hist.append(loss_GAN_iter / len(dataloader))
    loss_cycle_hist.append(loss_cycle_iter / len(dataloader))
    loss_D_hist.append(loss_D_iter / len(dataloader))

    # Save models checkpoints
    if not os.path.exists('output modified'):
        os.makedirs('output modified')
    torch.save(encoder.state_dict(), 'output modified/encoder.pth')
    torch.save(decoder_A2B.state_dict(), 'output modified/encoder_A2B.pth')
    torch.save(decoder_B2A.state_dict(), 'output modified/encoder_B2A.pth')
    torch.save(netD_A.state_dict(), 'output modified/netD_A.pth')
    torch.save(netD_B.state_dict(), 'output modified/netD_B.pth')
    ###################################
Example #12
class Trainer:
    def __init__(self, device, dset, x_dim, c_dim, z_dim, n_train, n_test, lr,
                 layer_sizes, **kwargs):
        '''
        Trainer class
        Args:
            device (torch.device) : Use GPU or CPU
            x_dim (int)           : Feature dimension
            c_dim (int)           : Attribute dimension
            z_dim (int)           : Latent dimension
            n_train (int)         : Number of training classes
            n_test (int)          : Number of testing classes
            lr (float)            : Learning rate for VAE
            layer_sizes(dict)     : List containing the hidden layer sizes
            **kwargs              : Flags for using various regularizations
        '''
        self.device = device
        self.dset = dset
        self.lr = lr
        self.z_dim = z_dim

        self.n_train = n_train
        self.n_test = n_test
        self.gzsl = kwargs.get('gzsl', False)
        if self.gzsl:
            self.n_test = n_train + n_test

        # flags for various regularizers
        self.use_da = kwargs.get('use_da', False)
        self.use_ca = kwargs.get('use_ca', False)
        self.use_support = kwargs.get('use_support', False)

        self.x_encoder = Encoder(x_dim, layer_sizes['x_enc'],
                                 z_dim).to(self.device)
        self.x_decoder = Decoder(z_dim, layer_sizes['x_dec'],
                                 x_dim).to(self.device)

        self.c_encoder = Encoder(c_dim, layer_sizes['c_enc'],
                                 z_dim).to(self.device)
        self.c_decoder = Decoder(z_dim, layer_sizes['c_dec'],
                                 c_dim).to(self.device)

        self.support_classifier = Classifier(z_dim,
                                             self.n_train).to(self.device)

        params = list(self.x_encoder.parameters()) + \
                 list(self.x_decoder.parameters()) + \
                 list(self.c_encoder.parameters()) + \
                 list(self.c_decoder.parameters())

        if self.use_support:
            params += list(self.support_classifier.parameters())

        self.optimizer = optim.Adam(params, lr=lr)

        self.final_classifier = Classifier(z_dim, self.n_test).to(self.device)
        self.final_cls_optim = optim.RMSprop(
            self.final_classifier.parameters(), lr=2e-4)
        self.criterion = nn.CrossEntropyLoss()

        self.vae_save_path = './saved_models'
        self.disc_save_path = './saved_models/disc_model_%s.pth' % self.dset

    def fit_VAE(self, x, c, y, ep):
        '''
        Train on 1 minibatch of data
        Args:
            x (torch.Tensor) : Features of size (batch_size, 2048)
            c (torch.Tensor) : Attributes of size (batch_size, attr_dim)
            y (torch.Tensor) : Target labels of size (batch_size,)
            ep (int)         : Epoch number
        Returns:
            Loss for the minibatch -
            3-tuple with (vae_loss, distributn loss, cross_recon loss)
        '''
        self.anneal_parameters(ep)

        x = Variable(x.float()).to(self.device)
        c = Variable(c.float()).to(self.device)
        y = Variable(y.long()).to(self.device)

        # VAE for image embeddings
        mu_x, logvar_x = self.x_encoder(x)
        z_x = self.reparameterize(mu_x, logvar_x)
        x_recon = self.x_decoder(z_x)

        # VAE for class embeddings
        mu_c, logvar_c = self.c_encoder(c)
        z_c = self.reparameterize(mu_c, logvar_c)
        c_recon = self.c_decoder(z_c)

        # reconstruction loss
        L_recon_x = self.compute_recon_loss(x, x_recon)
        L_recon_c = self.compute_recon_loss(c, c_recon)

        # KL divergence loss
        D_kl_x = self.compute_kl_div(mu_x, logvar_x)
        D_kl_c = self.compute_kl_div(mu_c, logvar_c)

        # VAE Loss = recon_loss - KL_Divergence_loss
        L_vae_x = L_recon_x - self.beta * D_kl_x
        L_vae_c = L_recon_c - self.beta * D_kl_c
        L_vae = L_vae_x + L_vae_c

        # calculate cross alignment loss
        L_ca = torch.zeros(1).to(self.device)
        if self.use_ca:
            x_recon_from_c = self.x_decoder(z_c)
            L_ca_x = self.compute_recon_loss(x, x_recon_from_c)

            c_recon_from_x = self.c_decoder(z_x)
            L_ca_c = self.compute_recon_loss(c, c_recon_from_x)

            L_ca = L_ca_x + L_ca_c

        # calculate distribution alignment loss
        L_da = torch.zeros(1).to(self.device)
        if self.use_da:
            L_da = 2 * self.compute_da_loss(mu_x, logvar_x, mu_c, logvar_c)

        # calculate loss from support classifier
        L_sup = torch.zeros(1).to(self.device)
        if self.use_support:
            y_prob = F.softmax(self.support_classifier(z_x), dim=0)
            log_prob = torch.log(torch.gather(y_prob, 1, y.unsqueeze(1)))
            L_sup = -1 * torch.mean(log_prob)

        total_loss = L_vae + self.gamma * L_ca + self.delta * L_da + self.alpha * L_sup

        self.optimizer.zero_grad()
        total_loss.backward()
        self.optimizer.step()

        return L_vae.item(), L_da.item(), L_ca.item()

    def reparameterize(self, mu, log_var):
        '''
        Reparameterization trick using unimodal gaussian
        '''
        # eps = Variable(torch.randn(mu.size())).to(self.device)
        eps = Variable(torch.randn(mu.size()[0],
                                   1).expand(mu.size())).to(self.device)
        z = mu + torch.exp(log_var / 2.0) * eps
        return z

    def anneal_parameters(self, epoch):
        '''
        Change weight factors of various losses based on epoch number
        '''
        # weight of kl divergence loss
        if epoch <= 90:
            self.beta = 0.0026 * epoch

        # weight of Cross Alignment loss
        if epoch < 20:
            self.gamma = 0
        if epoch >= 20 and epoch <= 75:
            self.gamma = 0.044 * (epoch - 20)

        # weight of distribution alignment loss
        if epoch < 5:
            self.delta = 0
        if epoch >= 5 and epoch <= 22:
            self.delta = 0.54 * (epoch - 5)

        # weight of support loss
        if epoch < 5:
            self.alpha = 0
        else:
            self.alpha = 0.01

    def compute_recon_loss(self, x, x_recon):
        '''
        Compute the reconstruction error.
        '''
        l1_loss = torch.abs(x - x_recon).sum()
        # l1_loss = torch.abs(x - x_recon).sum(dim=1).mean()
        return l1_loss

    def compute_kl_div(self, mu, log_var):
        '''
        Compute KL Divergence between N(mu, var) & N(0, 1).
        '''
        kld = 0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum()
        # kld = 0.5 * (1 + log_var - mu.pow(2) - log_var.exp()).sum(dim=1).mean()
        return kld

    def compute_da_loss(self, mu1, log_var1, mu2, log_var2):
        '''
        Computes Distribution Alignment loss between 2 normal distributions.
        Uses Wasserstein distance as distance measure.
        '''
        l1 = (mu1 - mu2).pow(2).sum(dim=1)

        std1 = (log_var1 / 2.0).exp()
        std2 = (log_var2 / 2.0).exp()
        l2 = (std1 - std2).pow(2).sum(dim=1)

        l_da = torch.sqrt(l1 + l2).sum()
        return l_da

    def fit_final_classifier(self, x, y):
        '''
        Train the final classifier on synthetically generated data
        '''
        x = Variable(x.float()).to(self.device)
        y = Variable(y.long()).to(self.device)

        logits = self.final_classifier(x)
        loss = self.criterion(logits, y)

        self.final_cls_optim.zero_grad()
        loss.backward()
        self.final_cls_optim.step()

        return loss.item()

    def fit_MOE(self, x, y):
        '''
        Trains the synthetic dataset on a MoE model
        '''

    def get_vae_savename(self):
        '''
        Returns a string indicative of various flags used during training and
        dataset used. Works as a unique name for saving models
        '''
        flags = ''
        if self.use_da:
            flags += '-da'
        if self.use_ca:
            flags += '-ca'
        if self.use_support:
            flags += '-support'
        model_name = 'vae_model__dset-%s__lr-%f__z-%d__%s.pth' % (
            self.dset, self.lr, self.z_dim, flags)
        return model_name

    def save_VAE(self, ep):
        state = {
            'epoch': ep,
            'x_encoder': self.x_encoder.state_dict(),
            'x_decoder': self.x_decoder.state_dict(),
            'c_encoder': self.c_encoder.state_dict(),
            'c_decoder': self.c_decoder.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        model_name = self.get_vae_savename()
        torch.save(state, os.path.join(self.vae_save_path, model_name))

    def load_models(self, model_path=''):
        if model_path == '':
            model_path = os.path.join(self.vae_save_path,
                                      self.get_vae_savename())

        ep = 0
        if os.path.exists(model_path):
            checkpoint = torch.load(model_path)
            self.x_encoder.load_state_dict(checkpoint['x_encoder'])
            self.x_decoder.load_state_dict(checkpoint['x_decoder'])
            self.c_encoder.load_state_dict(checkpoint['c_encoder'])
            self.c_decoder.load_state_dict(checkpoint['c_decoder'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            ep = checkpoint['epoch']

        return ep

    def create_syn_dataset(self,
                           test_labels,
                           attributes,
                           seen_dataset,
                           n_samples=400):
        '''
        Creates a synthetic dataset based on attribute vectors of unseen class
        Args:
            test_labels: A dict with key as original serial number in provided
                dataset and value as the index which is predicted during
                classification by network
            attributes: A np array containing class attributes for each class
                of dataset
            seen_dataset: A list of 3-tuple (x, _, y) where x belongs to one of the
                seen classes and y is corresponding label. Used for generating
                latent representations of seen classes in GZSL
            n_samples: Number of samples of each unseen class to be generated(Default: 400)
        Returns:
            A list of 3-tuple (z, _, y) where z is latent representations and y is
            corresponding label
        '''
        syn_dataset = []
        for test_cls, idx in test_labels.items():
            attr = attributes[test_cls - 1]

            self.c_encoder.eval()
            c = Variable(torch.FloatTensor(attr).unsqueeze(0)).to(self.device)
            mu, log_var = self.c_encoder(c)

            Z = torch.cat(
                [self.reparameterize(mu, log_var) for _ in range(n_samples)])

            syn_dataset.extend([(Z[i], test_cls, idx)
                                for i in range(n_samples)])

        if seen_dataset is not None:
            self.x_encoder.eval()
            for (x, att_idx, y) in seen_dataset:
                x = Variable(torch.FloatTensor(x).unsqueeze(0)).to(self.device)
                mu, log_var = self.x_encoder(x)
                z = self.reparameterize(mu, log_var).squeeze()
                syn_dataset.append((z, att_idx, y))

        return syn_dataset

    def compute_accuracy(self, generator):
        y_real_list, y_pred_list = [], []

        for idx, (x, _, y) in enumerate(generator):
            x = Variable(x.float()).to(self.device)
            y = Variable(y.long()).to(self.device)

            self.final_classifier.eval()
            self.x_encoder.eval()
            mu, log_var = self.x_encoder(x)
            logits = self.final_classifier(mu)

            _, y_pred = logits.max(dim=1)

            y_real = y.detach().cpu().numpy()
            y_pred = y_pred.detach().cpu().numpy()

            y_real_list.extend(y_real)
            y_pred_list.extend(y_pred)

        ## We have sequence of real and predicted labels
        ## find seen and unseen classes accuracy

        if self.gzsl:
            y_real_list = np.asarray(y_real_list)
            y_pred_list = np.asarray(y_pred_list)

            y_seen_real = np.extract(y_real_list < self.n_train, y_real_list)
            y_seen_pred = np.extract(y_real_list < self.n_train, y_pred_list)

            y_unseen_real = np.extract(y_real_list >= self.n_train,
                                       y_real_list)
            y_unseen_pred = np.extract(y_real_list >= self.n_train,
                                       y_pred_list)

            acc_seen = accuracy_score(y_seen_real, y_seen_pred)
            acc_unseen = accuracy_score(y_unseen_real, y_unseen_pred)

            return acc_seen, acc_unseen

        else:
            return accuracy_score(y_real_list, y_pred_list)
#     print("Test set results:",
#           "loss= {:.4f}".format(loss_test.data[0]),
#           "accuracy= {:.4f}".format(acc_test.data[0]))

if __name__ == "__main__":
    # Train model
    t_total = time.time()
    loss_values = []
    bad_counter = 0
    best = args.epochs + 1
    best_epoch = 0
    logger.info('Trainer Built')
    for epoch in range(args.epochs):
        loss_values.append(train(epoch, train_loader, val_loader, logging))

        torch.save(model.state_dict(), os.path.join(args.output_path, '{}.pkl'.format(epoch)))
        if loss_values[-1] < best:
            best = loss_values[-1]
            best_epoch = epoch
            bad_counter = 0
        else:
            bad_counter += 1

        if bad_counter == args.patience:
            break

        files = glob.glob(os.path.join(args.output_path, '*.pkl'))
        for file in files:
            epoch_nb = int(os.path.basename(file).split('.')[0])
            if epoch_nb < best_epoch:
                os.remove(file)
Example #14
class DQNAgent():
    """Interacts with and learns from the environment."""
    def __init__(self, state_size, action_size, config):
        """Initialize an Agent object.
        
        Params
        ======
            state_size (int): dimension of each state
            action_size (int): dimension of each action
            seed (int): random seed
        """
        self.state_size = state_size
        self.action_size = action_size
        random.seed(config["seed"])
        self.seed = config["seed"]
        self.gamma = 0.99
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.device = config["device"]
        # Q-Network
        self.qnetwork_local = QNetwork(state_size, action_size, self.fc1,
                                       self.fc2, self.seed).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.fc1,
                                        self.fc2, self.seed).to(self.device)

        self.optimizer = optim.Adam(self.qnetwork_local.parameters(),
                                    lr=self.lr)
        self.encoder = Encoder(config).to(self.device)
        self.encoder_optimizer = torch.optim.Adam(self.encoder.parameters(),
                                                  self.lr)

        # Replay memory

        # Initialize time step (for updating every UPDATE_EVERY steps)
        self.t_step = 0

    def step(self, memory, writer):
        self.t_step += 1
        if len(memory) > self.batch_size:
            if self.t_step % 4 == 0:
                experiences = memory.sample(self.batch_size)
                self.learn(experiences, writer)

    def act(self, state, eps=0.):
        """Returns actions for given state as per current policy.
        
        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        state = state.type(torch.float32).div_(255)
        self.qnetwork_local.eval()
        self.encoder.eval()
        with torch.no_grad():
            state = self.encoder.create_vector(state)
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()
        self.encoder.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def learn(self, experiences, writer):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples 
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        states = states.type(torch.float32).div_(255)
        states = self.encoder.create_vector(states)
        next_states = next_states.type(torch.float32).div_(255)
        next_states = self.encoder.create_vector(next_states)
        actions = actions.type(torch.int64)
        # Get max predicted Q values (for next states) from target model
        Q_targets_next = self.qnetwork_target(next_states).detach().max(
            1)[0].unsqueeze(1)
        # Compute Q targets for current states
        Q_targets = rewards + (self.gamma * Q_targets_next * dones)

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        writer.add_scalar('Q_loss', loss, self.t_step)
        # Minimize the loss
        self.optimizer.zero_grad()
        self.encoder_optimizer.zero_grad()

        loss.backward()
        self.optimizer.step()
        self.encoder_optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target)

    def soft_update(self, local_model, target_model):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(self.tau * local_param.data +
                                    (1.0 - self.tau) * target_param.data)

    def save(self, filename):
        """
        """
        mkdir("", filename)
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        torch.save(self.optimizer.state_dict(),
                   filename + "_q_net_optimizer.pth")
        torch.save(self.encoder.state_dict(), filename + "_encoder.pth")
        torch.save(self.encoder_optimizer.state_dict(),
                   filename + "_encoder_optimizer.pth")
        print("Save models to {}".format(filename))
Example #15
        output = torch.cat((good_pairs, bad_pairs))
        target = torch.cat((good_target, bad_target))

        # print(output[:10], output[-10:])
        loss = loss_function(output, target)
        # print(loss.item())
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    print(epoch_loss)

    intloss = int(epoch_loss * 10000) / 10000
    if epoch % config.save_frequency == 0:
        torch.save(
            encoder.state_dict(),
            f'{config.saved_models_folder}/encoder_epoch{epoch}_loss{intloss}.pth'
        )
        torch.save(
            siamese_network.state_dict(),
            f'{config.saved_models_folder}/siamese_network_epoch{epoch}_loss{intloss}.pth'
        )
        print('Saved models, epoch: ' + str(epoch))

for batch in train_data_loader:
    batch = batch[0]
    break
batch = batch.to(device)
batch = transform2(batch)

print('test')
Example #16
        # # src_output_dm_conf = dm_classifier(src_feature)
        # # tgt_output_dm_conf = dm_classifier(tgt_feature)
        # output_dm_conf = dm_classifier(feature_concat)
        # uni_distrib = torch.FloatTensor(output_dm_conf.size()).uniform_(0, 1)
        # if cuda:
        #     uni_distrib = uni_distrib.cuda()
        # uni_distrib = Variable(uni_distrib)
        # # loss_conf = lam * criterion_kl(tgt_output_dm_conf, uni_distrib)
        # loss_conf = - lam * (torch.sum(uni_distrib * torch.log(output_dm_conf)))/float(output_dm_conf.size(0))
        # loss_conf.backward()
        # optimizer_conf.step()
        # acc
        tgt_output_cl_score = F.softmax(tgt_output_cl, dim=1)  # softmax first
        pred = tgt_output_cl_score.data.max(1, keepdim=True)[1]
        correct += pred.eq(tgt_label_cl.data.view_as(pred)).cpu().sum()

    acc = correct / len(tgt_train_loader.dataset)
    print("epoch: %d, class loss: %f, acc: %f" % (epoch, loss.data[0], acc))

    # save parameters
    if (epoch % interval == 0):
        torch.save(encoder.state_dict(),
                   "./checkpoints/a2w/no_soft_encoder{}.pth".format(epoch))
        torch.save(
            cl_classifier.state_dict(),
            "./checkpoints/a2d/no_soft_class_classifier{}.pth".format(epoch))

torch.save(encoder.state_dict(), "./checkpoints/a2w/no_soft_encoder_final.pth")
torch.save(cl_classifier.state_dict(),
           "./checkpoints/a2d/no_soft_class_classifier_final.pth")
Example #17
def train_CIFAR10(opt):

    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt
    params = get_config(opt.config)

    save_path = os.path.join(
        params['save_path'],
        datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(save_path, exist_ok=True)
    shutil.copy('models.py', os.path.join(save_path, 'models.py'))
    shutil.copy('train.py', os.path.join(save_path, 'train.py'))
    shutil.copy(opt.config,
                os.path.join(save_path, os.path.basename(opt.config)))

    cuda = torch.cuda.is_available()
    gpu_ids = [i for i in range(torch.cuda.device_count())]

    TensorType = torch.cuda.FloatTensor if cuda else torch.Tensor

    data_path = os.path.join(params['data_root'], 'cifar10')

    os.makedirs(data_path, exist_ok=True)

    train_dataset = datasets.CIFAR10(root=data_path,
                                     train=True,
                                     download=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5),
                                                              (0.5, 0.5, 0.5))
                                     ]))

    val_dataset = datasets.CIFAR10(root=data_path,
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]))

    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'] * len(gpu_ids),
                              shuffle=True,
                              num_workers=params['num_workers'],
                              pin_memory=cuda)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            num_workers=params['num_workers'],
                            pin_memory=cuda)

    data_variance = np.var(train_dataset.train_data / 255.0)

    encoder = Encoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    decoder = Decoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])

    vq = VectorQuantizer(params['k'], params['d'], params['beta'],
                         params['decay'], TensorType)

    if params['checkpoint'] != None:
        checkpoint = torch.load(params['checkpoint'])

        params['start_epoch'] = checkpoint['epoch']
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        vq.load_state_dict(checkpoint['vq'])

    model = VQVAE(encoder, decoder, vq)

    if cuda:
        model = nn.DataParallel(model.cuda(), device_ids=gpu_ids)

    parameters = list(model.parameters())
    opt = torch.optim.Adam([p for p in parameters if p.requires_grad],
                           lr=params['lr'])

    for epoch in range(params['start_epoch'], params['num_epochs']):
        train_bar = tqdm(train_loader)
        for data, _ in train_bar:
            if cuda:
                data = data.cuda()
            opt.zero_grad()

            vq_loss, data_recon, _ = model(data)
            recon_error = torch.mean((data_recon - data)**2) / data_variance
            loss = recon_error + vq_loss.mean()
            loss.backward()
            opt.step()

            train_bar.set_description('Epoch {}: loss {:.4f}'.format(
                epoch + 1,
                loss.mean().item()))

        model.eval()
        data_val, _ = next(iter(val_loader))

        if cuda:
            data_val = data_val.cuda()
        with torch.no_grad():
            _, data_recon_val, _ = model(data_val)

        # undo the Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) transform before saving
        plt.imsave(os.path.join(save_path, 'latest_val_recon.png'),
                   (make_grid(data_recon_val.cpu().data) * 0.5 +
                    0.5).numpy().transpose(1, 2, 0))
        plt.imsave(os.path.join(save_path, 'latest_val_orig.png'),
                   (make_grid(data_val.cpu().data) * 0.5 + 0.5).numpy().transpose(
                       1, 2, 0))

        model.train()

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'vq': vq.state_dict(),
            }, os.path.join(save_path, '{}_checkpoint.pth'.format(epoch)))
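The snippet above depends on a `get_config` helper that is not shown. A minimal sketch, assuming the config file is plain YAML whose keys (save_path, data_root, batch_size, ...) match the `params` lookups used in the example:

import yaml

def get_config(config_path):
    # Sketch of the assumed helper: read the YAML config into a plain dict.
    with open(config_path, 'r') as f:
        return yaml.safe_load(f)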
Example #18
0
]
encoder = Encoder(args, input_size=English.num_words)
decoder = Decoder(args, output_size=Vietnamese.num_words)

args.output_size = Vietnamese.num_words
if 'train' in args.mode:
    if args.checkpoint:
        encoder.load_state_dict(
            torch.load(os.path.join('model', 'encoder.pkl')))
        decoder.load_state_dict(
            torch.load(os.path.join('model', 'decoder.pkl')))
        encoder.eval()
        decoder.eval()
    utils.train_handler(args, encoder, decoder, English_tokens,
                        Vietnamese_tokens, English, Vietnamese)
    torch.save(encoder.state_dict(), os.path.join('model', 'encoder.pkl'))
    torch.save(decoder.state_dict(), os.path.join('model', 'decoder.pkl'))

elif 'translate' in args.mode:
    encoder.load_state_dict(torch.load(os.path.join('model', 'encoder.pkl')))
    decoder.load_state_dict(torch.load(os.path.join('model', 'decoder.pkl')))
    encoder.eval()
    decoder.eval()
    while True:
        print("Please Enter Input...")
        sentence = input()
        utils.translate(args, sentence, encoder, decoder, English, Vietnamese)
elif 'test' in args.mode:
    encoder.load_state_dict(torch.load(os.path.join('model', 'encoder.pkl')))
    decoder.load_state_dict(torch.load(os.path.join('model', 'decoder.pkl')))
    encoder.eval()
Example #19
0
        uni_distrib = torch.FloatTensor(output_dm_conf.size()).uniform_(0, 1)
        if cuda:
            uni_distrib = uni_distrib.cuda()
        uni_distrib = Variable(uni_distrib)
        # loss_conf = lam * criterion_kl(tgt_output_dm_conf, uni_distrib)
        loss_conf = -lam * (torch.sum(
            uni_distrib * torch.log(output_dm_conf))) / float(
                output_dm_conf.size(0))
        loss_conf.backward()
        optimizer_conf.step()
        # acc
        tgt_output_cl_score = F.softmax(tgt_output_cl, dim=1)  # softmax first
        pred = tgt_output_cl_score.data.max(1, keepdim=True)[1]
        correct += pred.eq(tgt_label_cl.data.view_as(pred)).cpu().sum()

    acc = correct / len(tgt_train_loader.dataset)
    print(
        "epoch: %d, class loss: %f, domain loss: %f, confusion loss: %f, acc: %f"
        % (epoch, loss.item(), loss_dm.item(), loss_conf.item(), acc))

    # save parameters
    if (epoch % interval == 0):
        torch.save(encoder.state_dict(),
                   "./checkpoints/a2d/encoder{}.pth".format(epoch))
        torch.save(cl_classifier.state_dict(),
                   "./checkpoints/a2d/class_classifier{}.pth".format(epoch))

torch.save(encoder.state_dict(), "./checkpoints/a2d/encoder_final.pth")
torch.save(cl_classifier.state_dict(),
           "./checkpoints/a2d/class_classifier_final.pth")
Example #20
0
        # # update encoder only using domain loss
        # optimizer_conf.zero_grad()
        # feature_concat =  torch.cat((src_feature, tgt_feature), 0)
        # # src_output_dm_conf = dm_classifier(src_feature)
        # # tgt_output_dm_conf = dm_classifier(tgt_feature)
        # output_dm_conf = dm_classifier(feature_concat)
        # uni_distrib = torch.FloatTensor(output_dm_conf.size()).uniform_(0, 1)
        # if cuda:
        #     uni_distrib = uni_distrib.cuda()
        # uni_distrib = Variable(uni_distrib)
        # # loss_conf = lam * criterion_kl(tgt_output_dm_conf, uni_distrib)
        # loss_conf = - lam * (torch.sum(uni_distrib * torch.log(output_dm_conf)))/float(output_dm_conf.size(0)) 
        # loss_conf.backward()
        # optimizer_conf.step()
        # acc
        tgt_output_cl_score = F.softmax(tgt_output_cl, dim=1) # softmax first
        pred = tgt_output_cl_score.data.max(1, keepdim=True)[1]
        correct += pred.eq(tgt_label_cl.data.view_as(pred)).cpu().sum()

    acc = correct / len(tgt_train_loader.dataset)
    print("epoch: %d, class loss: %f, acc: %f"%(epoch, loss.data[0], acc))

    # save parameters
    if (epoch % interval == 0):
        torch.save(encoder.state_dict(),
                   "./checkpoints/a2d/no_soft_encoder{}.pth".format(epoch))
        torch.save(
            cl_classifier.state_dict(),
            "./checkpoints/a2d/no_soft_class_classifier{}.pth".format(epoch))

torch.save(encoder.state_dict(),
           "./checkpoints/a2d/no_soft_encoder_final.pth")
torch.save(cl_classifier.state_dict(),
           "./checkpoints/a2d/no_soft_class_classifier_final.pth")

Example #21
0
                "epoch {} iter {}; loss_G: {:.4f}; loss_D: {:.4f}; latent: {:.4f}; KLD: {:.4f}"
                .format(epoch_id, idx, loss_G.item(),
                        vae_D_loss.item() + clr_D_loss.item(),
                        latent_loss.item(), kld_loss.item()))

            if (idx + 1) % (len(loader) // 5) == 0:
                # -------------------------------
                #  Save model
                # ------------------------------
                path = os.path.join(
                    checkpoints_path,
                    'bicycleGAN_epoch_' + str(epoch_id) + '_' + str(idx))
                torch.save(
                    {
                        'epoch': epoch_id,
                        'encoder_state_dict': encoder.state_dict(),
                        'generator_state_dict': generator.state_dict(),
                        'discriminator_state_dict': discriminator.state_dict(),
                        'optimizer_E': optimizer_E.state_dict(),
                        'optimizer_G': optimizer_G.state_dict(),
                        'optimizer_D': optimizer_D.state_dict()
                    }, path)

                # -------------------------------
                #  Visualization
                # ------------------------------
            if (idx + 1) % 100 == 0:
                vis_fake_B_encoded = denorm(
                    fake_B_encoded[0].detach()).cpu().data.numpy().astype(
                        np.uint8)
                vis_fake_B_random = denorm(
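`denorm` is not defined in this excerpt; given that its output is cast straight to np.uint8, a plausible sketch is a map from the [-1, 1] normalised range back to [0, 255]:

def denorm(x):
    # Sketch of the missing helper: undo a (-1, 1) normalisation and rescale
    # to the 0-255 image range expected by the uint8 cast above.
    return (x * 0.5 + 0.5).clamp(0, 1) * 255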
Example #22
0
def train(args):
    cfg_from_file(args.cfg)
    cfg.WORKERS = args.num_workers
    pprint.pprint(cfg)
    # set the seed manually
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    # define outputer
    outputer_train = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                              cfg.IMAGETEXT.SAVE_EVERY)
    outputer_val = Outputer(args.output_dir, cfg.IMAGETEXT.PRINT_EVERY,
                            cfg.IMAGETEXT.SAVE_EVERY)
    # define the dataset
    split_dir, bshuffle = 'train', True

    # Get data loader
    imsize = cfg.TREE.BASE_SIZE * (2**(cfg.TREE.BRANCH_NUM - 1))
    train_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.RandomCrop(imsize),
    ])
    val_transform = transforms.Compose([
        transforms.Resize(int(imsize * 76 / 64)),
        transforms.CenterCrop(imsize),
    ])
    if args.dataset == 'bird':
        train_dataset = ImageTextDataset(args.data_dir,
                                         split_dir,
                                         transform=train_transform,
                                         sample_type='train')
        val_dataset = ImageTextDataset(args.data_dir,
                                       'val',
                                       transform=val_transform,
                                       sample_type='val')
    elif args.dataset == 'coco':
        train_dataset = CaptionDataset(args.data_dir,
                                       split_dir,
                                       transform=train_transform,
                                       sample_type='train',
                                       coco_data_json=args.coco_data_json)
        val_dataset = CaptionDataset(args.data_dir,
                                     'val',
                                     transform=val_transform,
                                     sample_type='val',
                                     coco_data_json=args.coco_data_json)
    else:
        raise NotImplementedError

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=bshuffle,
        num_workers=int(cfg.WORKERS))
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=cfg.IMAGETEXT.BATCH_SIZE,
        shuffle=False,
        num_workers=1)
    # define the model and optimizer
    if args.raw_checkpoint != '':
        encoder, decoder = load_raw_checkpoint(args.raw_checkpoint)
    else:
        encoder = Encoder()
        decoder = DecoderWithAttention(
            attention_dim=cfg.IMAGETEXT.ATTENTION_DIM,
            embed_dim=cfg.IMAGETEXT.EMBED_DIM,
            decoder_dim=cfg.IMAGETEXT.DECODER_DIM,
            vocab_size=train_dataset.n_words)
        # load checkpoint
        if cfg.IMAGETEXT.CHECKPOINT != '':
            outputer_val.log("load model from: {}".format(
                cfg.IMAGETEXT.CHECKPOINT))
            encoder, decoder = load_checkpoint(encoder, decoder,
                                               cfg.IMAGETEXT.CHECKPOINT)

    encoder.fine_tune(False)
    # to cuda
    encoder = encoder.cuda()
    decoder = decoder.cuda()
    loss_func = torch.nn.CrossEntropyLoss()
    if args.eval:  # eval only
        outputer_val.log("only eval the model...")
        assert cfg.IMAGETEXT.CHECKPOINT != ''
        val_rtn_dict, outputer_val = validate_one_epoch(
            0, val_dataloader, encoder, decoder, loss_func, outputer_val)
        outputer_val.log("\n[valid]: {}\n".format(dict2str(val_rtn_dict)))
        return

    # define optimizer
    optimizer_encoder = torch.optim.Adam(encoder.parameters(),
                                         lr=cfg.IMAGETEXT.ENCODER_LR)
    optimizer_decoder = torch.optim.Adam(decoder.parameters(),
                                         lr=cfg.IMAGETEXT.DECODER_LR)
    encoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_encoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    decoder_lr_scheduler = torch.optim.lr_scheduler.StepLR(
        optimizer_decoder, step_size=10, gamma=cfg.IMAGETEXT.LR_GAMMA)
    print("train the model...")
    for epoch_idx in range(cfg.IMAGETEXT.EPOCH):
        # val_rtn_dict, outputer_val = validate_one_epoch(epoch_idx, val_dataloader, encoder,
        #         decoder, loss_func, outputer_val)
        # outputer_val.log("\n[valid] epoch: {}, {}".format(epoch_idx, dict2str(val_rtn_dict)))
        train_rtn_dict, outputer_train = train_one_epoch(
            epoch_idx, train_dataloader, encoder, decoder, optimizer_encoder,
            optimizer_decoder, loss_func, outputer_train)
        # adjust lr scheduler
        encoder_lr_scheduler.step()
        decoder_lr_scheduler.step()

        outputer_train.log("\n[train] epoch: {}, {}\n".format(
            epoch_idx, dict2str(train_rtn_dict)))
        val_rtn_dict, outputer_val = validate_one_epoch(
            epoch_idx, val_dataloader, encoder, decoder, loss_func,
            outputer_val)
        outputer_val.log("\n[valid] epoch: {}, {}\n".format(
            epoch_idx, dict2str(val_rtn_dict)))

        outputer_val.save_step({
            "encoder": encoder.state_dict(),
            "decoder": decoder.state_dict()
        })
    outputer_val.save({
        "encoder": encoder.state_dict(),
        "decoder": decoder.state_dict()
    })
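`dict2str` is referenced throughout the example but not defined here; a minimal sketch that turns the metric dictionaries returned by the train/validate helpers into a single log line:

def dict2str(metrics):
    # Sketch only: format {'loss': 1.234, ...} as "loss: 1.2340, ...".
    return ", ".join("{}: {:.4f}".format(k, v) for k, v in metrics.items())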
Example #23
0
                iteration,
            )
            writer.add_scalar(
                "heuristic_discriminator_average_20_obs",
                collector_heuristic_discriminator.mean(),
                iteration,
            )
            writer.add_scalar("codes_min_over_20_obs",
                              collector_codes_min.min(), iteration)
            writer.add_scalar("codes_max_over_20_obs",
                              collector_codes_max.max(), iteration)

            if iteration % (knobs["time_to_collect"] * 4) == 0:

                it_encoder_parameters = encoder.parameters()
                for k, v in encoder.state_dict().items():
                    if k.find("bias") != -1 or k.find("weight") != -1:
                        writer.add_histogram("encoder/" + k.replace(".", "/"),
                                             v, iteration)
                        writer.add_histogram(
                            "encoder/" + k.replace(".", "/") + "/grad",
                            next(it_encoder_parameters).grad,
                            iteration,
                        )
                it_decoder_parameters = decoder.parameters()
                for k, v in decoder.state_dict().items():
                    if k.find("bias") != -1 or k.find("weight") != -1:
                        writer.add_histogram("decoder/" + k.replace(".", "/"),
                                             v, iteration)
                        writer.add_histogram(
                            "decoder/" + k.replace(".", "/") + "/grad",
Example #24
0
File: train.py Project: nnuq/tpu
def main(index, args):
    device = xm.xla_device()

    gen_net = Generator(args).to(device)
    dis_net = Discriminator(args).to(device)
    enc_net = Encoder(args).to(device)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown init type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    enc_net.apply(weights_init)

    ae_recon_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_recon_lr, (args.beta1, args.beta2))
    ae_reg_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_reg_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(dis_net.parameters(), args.d_lr,
                                     (args.beta1, args.beta2))
    gen_optimizer = torch.optim.Adam(gen_net.parameters(), args.g_lr,
                                     (args.beta1, args.beta2))

    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    valid_loader = dataset.valid
    para_loader = pl.ParallelLoader(train_loader, [device])

    fid_stat = str(pathlib.Path(
        __file__).parent.absolute()) + '/fid_stat/fid_stat_cifar10_test.npz'
    if not os.path.exists(fid_stat):
        download_stat_cifar10_test()

    is_best = True
    args.num_epochs = int(np.ceil(args.num_iter / len(train_loader)))

    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    ae_recon_scheduler = LinearLrDecay(ae_recon_optimizer, args.ae_recon_lr, 0,
                                       args.num_iter / 2, args.num_iter)
    ae_reg_scheduler = LinearLrDecay(ae_reg_optimizer, args.ae_reg_lr, 0,
                                     args.num_iter / 2, args.num_iter)

    # initial
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        enc_net.load_state_dict(checkpoint['enc_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ae_recon_optimizer.load_state_dict(checkpoint['ae_recon_optimizer'])
        ae_reg_optimizer.load_state_dict(checkpoint['ae_reg_optimizer'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        logs_dir = str(pathlib.Path(__file__).parent.parent) + '/logs'
        args.path_helper = set_log_dir(logs_dir, args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.num_epochs)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler, ae_recon_scheduler,
                         ae_reg_scheduler)
        train(device, args, gen_net, dis_net, enc_net, gen_optimizer,
              dis_optimizer, ae_recon_optimizer, ae_reg_optimizer, para_loader,
              epoch, writer_dict, lr_schedulers)
        if epoch and epoch % args.val_freq == 0 or epoch == args.num_epochs - 1:
            fid_score = validate(args, fid_stat, gen_net, writer_dict,
                                 valid_loader)
            logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'enc_state_dict': enc_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ae_recon_optimizer': ae_recon_optimizer.state_dict(),
                'ae_reg_optimizer': ae_reg_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
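`LinearLrDecay` is not defined in this snippet. A minimal sketch consistent with how it is constructed above (start lr, end lr of 0, decay between num_iter/2 and num_iter); the step(current_step) interface is an assumption:

class LinearLrDecay:
    # Sketch only: linearly anneal the learning rate from start_lr to end_lr
    # between decay_start_step and decay_end_step, constant outside that range.
    def __init__(self, optimizer, start_lr, end_lr, decay_start_step, decay_end_step):
        self.optimizer = optimizer
        self.start_lr, self.end_lr = start_lr, end_lr
        self.decay_start_step, self.decay_end_step = decay_start_step, decay_end_step

    def step(self, current_step):
        if current_step <= self.decay_start_step:
            lr = self.start_lr
        elif current_step >= self.decay_end_step:
            lr = self.end_lr
        else:
            frac = ((current_step - self.decay_start_step) /
                    (self.decay_end_step - self.decay_start_step))
            lr = self.start_lr + frac * (self.end_lr - self.start_lr)
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr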
Example #25
0
                         dropout=0.3,
                         lr=args.learning_rate,
                         activation_fn=nn.LeakyReLU(0.2)).to(device)

    print("========== Encoder ==========\n{}".format(enc))

    print("========== Decoder ==========\n{}".format(dec))

    print("========== Discriminator ==========\n{}".format(disc))

    for epoch in range(1, args.num_epochs + 1):
        print("========== Start epoch {} at {} ==========".format(
            epoch,
            datetime.now().strftime("%H:%M:%S")))

        train(epoch, enc, dec, disc, prior_size, train_dl, TEXT.vocab, device)
        validate(epoch, enc, dec, disc, prior_size, valid_dl, TEXT.vocab,
                 device)

        print_decoded(enc, dec, gen_dl, vocab=TEXT.vocab, device=device)
        print_sample(dec,
                     sample_size=prior_size,
                     max_seq_len=41,
                     vocab=TEXT.vocab,
                     style_vocab=LABEL.vocab,
                     device=device)

    torch.save(enc.state_dict(), 'rcaae.enc.pt')
    torch.save(dec.state_dict(), 'rcaae.dec.pt')
    torch.save(disc.state_dict(), 'rcaae.disc.pt')
Example #26
0
class AAETrainer(AbstractTrainer):
    def __init__(self, opt):
        super().__init__(opt)

        print('[info] Dataset:', self.opt.dataset)
        print('[info] Alpha = ', self.opt.alpha)
        print('[info] Latent dimension = ', self.opt.latent_dim)

        self.opt = opt
        self.start_visdom()

    def start_visdom(self):
        self.vis = utils.Visualizer(env='Adversarial AutoEncoder Training',
                                    port=8888)

    def build_network(self):
        print('[info] Build the network architecture')
        self.encoder = Encoder(z_dim=self.opt.latent_dim)
        if self.opt.dataset == 'SMPL':
            num_verts = 6890
        elif self.opt.dataset == 'all_animals':
            num_verts = 3889
        else:
            raise ValueError('Unknown dataset: {}'.format(self.opt.dataset))
        self.decoder = Decoder(num_verts=num_verts, z_dim=self.opt.latent_dim)
        self.discriminator = Discriminator(input_dim=self.opt.latent_dim)

        self.encoder.cuda()
        self.decoder.cuda()
        self.discriminator.cuda()

    def build_optimizer(self):
        print('[info] Build the optimizer')
        self.optim_dis = optim.SGD(self.discriminator.parameters(),
                                   lr=self.opt.learning_rate)
        self.optim_AE = optim.Adam(itertools.chain(self.encoder.parameters(),
                                                   self.decoder.parameters()),
                                   lr=self.opt.learning_rate)

    def build_dataset_train(self):
        train_data = ACAPData(mode='train', name=self.opt.dataset)
        self.num_train_data = len(train_data)
        print('[info] Number of training samples = ', self.num_train_data)
        self.train_loader = torch.utils.data.DataLoader(
            train_data, batch_size=self.opt.batch_size, shuffle=True)

    def build_dataset_valid(self):
        valid_data = ACAPData(mode='valid', name=self.opt.dataset)
        self.num_valid_data = len(valid_data)
        print('[info] Number of validation samples = ', self.num_valid_data)
        self.valid_loader = torch.utils.data.DataLoader(valid_data,
                                                        batch_size=128,
                                                        shuffle=True)

    def build_losses(self):
        print('[info] Build the loss functions')
        self.mseLoss = torch.nn.MSELoss()
        self.ganLoss = torch.nn.BCELoss()

    def print_iteration_stats(self):
        """
        print stats at each iteration
        """
        print(
            '\r[Epoch %d] [Iteration %d/%d] enc = %f dis = %f rec = %f' %
            (self.epoch, self.iteration,
             int(self.num_train_data / self.opt.batch_size),
             self.enc_loss.item(), self.dis_loss.item(), self.rec_loss.item()),
            end='')

    def train_iteration(self):

        self.encoder.train()
        self.decoder.train()
        self.discriminator.train()

        x = self.data.cuda()

        z = self.encoder(x)
        ''' Discriminator '''
        # sample from N(0, I)
        z_real = Variable(torch.randn(z.size(0), z.size(1))).cuda()

        y_real = Variable(torch.ones(z.size(0))).cuda()
        dis_real_loss = self.ganLoss(
            self.discriminator(z_real).view(-1), y_real)

        y_fake = Variable(torch.zeros(z.size(0))).cuda()
        dis_fake_loss = self.ganLoss(self.discriminator(z).view(-1), y_fake)

        self.optim_dis.zero_grad()
        self.dis_loss = 0.5 * (dis_fake_loss + dis_real_loss)
        self.dis_loss.backward(retain_graph=True)
        self.optim_dis.step()
        self.dis_losses.append(self.dis_loss.item())
        ''' Autoencoder '''
        # The encoder tries to produce latent vectors that are close to the prior.
        y_real = Variable(torch.ones(z.size(0))).cuda()
        self.enc_loss = self.ganLoss(self.discriminator(z).view(-1), y_real)

        # The decoder tries to make the reconstruction as close to the input as possible.
        rec = self.decoder(z)
        self.rec_loss = self.mseLoss(rec, x)

        # There is a trade-off here:
        # latent regularization vs. reconstruction quality.
        self.EG_loss = self.opt.alpha * self.enc_loss + (
            1 - self.opt.alpha) * self.rec_loss

        self.optim_AE.zero_grad()
        self.EG_loss.backward()
        self.optim_AE.step()

        self.enc_losses.append(self.enc_loss.item())
        self.rec_losses.append(self.rec_loss.item())

        self.print_iteration_stats()
        self.increment_iteration()

    def train_epoch(self):

        self.reset_iteration()
        self.dis_losses = []
        self.enc_losses = []
        self.rec_losses = []
        for step, data in enumerate(self.train_loader):
            self.data = data
            self.train_iteration()

        self.dis_losses = torch.Tensor(self.dis_losses)
        self.dis_losses = torch.mean(self.dis_losses)

        self.enc_losses = torch.Tensor(self.enc_losses)
        self.enc_losses = torch.mean(self.enc_losses)

        self.rec_losses = torch.Tensor(self.rec_losses)
        self.rec_losses = torch.mean(self.rec_losses)

        self.vis.draw_line(win='Encoder Loss', x=self.epoch, y=self.enc_losses)
        self.vis.draw_line(win='Discriminator Loss',
                           x=self.epoch,
                           y=self.dis_losses)
        self.vis.draw_line(win='Reconstruction Loss',
                           x=self.epoch,
                           y=self.rec_losses)

    def valid_iteration(self):

        self.encoder.eval()
        self.decoder.eval()
        self.discriminator.eval()

        x = self.data.cuda()
        z = self.encoder(x)
        recon = self.decoder(z)

        # loss
        rec_loss = self.mseLoss(recon, x)
        self.rec_loss.append(rec_loss.item())
        self.increment_iteration()

    def valid_epoch(self):
        self.reset_iteration()
        self.rec_loss = []
        for step, data in enumerate(self.valid_loader):
            self.data = data
            self.valid_iteration()

        self.rec_loss = torch.Tensor(self.rec_loss)
        self.rec_loss = torch.mean(self.rec_loss)
        self.vis.draw_line(win='Valid reconstruction loss',
                           x=self.epoch,
                           y=self.rec_loss)

    def save_network(self):
        print("\n[info] saving net...")
        torch.save(self.encoder.state_dict(),
                   f"{self.opt.save_path}/Encoder.pth")
        torch.save(self.decoder.state_dict(),
                   f"{self.opt.save_path}/Decoder.pth")
        torch.save(self.discriminator.state_dict(),
                   f"{self.opt.save_path}/Discriminator.pth")
Example #27
0
            y, M = encoder(x)
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
            loss_mutual_information = loss_fn(y, M, M_prime)
            loss_total = loss_mutual_information
            train_loss.append(loss_total.item())
            batch.set_description(str(epoch) + ' Loss: ' + str(stats.mean(train_loss[-20:])))
            loss_total.backward()
            encoder_optim.step()
            loss_optim.step()

        if epoch % 10 == 0:
            root = Path(r'models')
            enc_file = root / Path('encoder' + str(epoch) + '.wgt')
            loss_file = root / Path('loss' + str(epoch) + '.wgt')
            enc_file.parent.mkdir(parents=True, exist_ok=True)
            torch.save(encoder.state_dict(), str(enc_file))
            torch.save(loss_fn.state_dict(), str(loss_file))

        if epoch > 1: 
            with open('loss.pickle', 'rb') as handle:
                loss_dict = pickle.load(handle)
                loss_dict[str(epoch)] = stats.mean(train_loss[-20:])
            with open('loss.pickle', 'wb') as handle:
                pickle.dump(loss_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
        else:
            with open('loss.pickle', 'wb') as handle:
                loss_dict = {}
                loss_dict[str(epoch)] = stats.mean(train_loss[-20:])
                pickle.dump(loss_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
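The negative pairing above (`M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)`) rolls the batch by one position so that every global code `y` is matched against a feature map from a different image, which is exactly what the mutual-information loss needs. A tiny standalone illustration (toy values, not from the original project):

import torch

M = torch.arange(4).view(4, 1)                       # stand-in feature maps, batch of 4
M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
print(M.squeeze().tolist())        # [0, 1, 2, 3]
print(M_prime.squeeze().tolist())  # [1, 2, 3, 0] -- every row is now mismatched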
            
Example #28
0
def main(args):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    if not os.path.exists(args.models_dir):
        os.makedirs(args.models_dir)

    if args.build_vocab:
        print(
            f"Building vocabulary from captions at {args.captions_json} and with count threshold={args.threshold}"
        )
        vocab_object = build_vocab(args.captions_json, args.threshold)
        with open(args.vocab_path, "wb") as vocab_f:
            pickle.dump(vocab_object, vocab_f)
        print(
            f"Saved the vocabulary object to {args.vocab_path}, total size={len(vocab_object)}"
        )
    else:
        with open(args.vocab_path, 'rb') as f:
            vocab_object = pickle.load(f)
        print(
            f"Loaded the vocabulary object from {args.vocab_path}, total size={len(vocab_object)}"
        )

    if args.glove_embed_path is not None:
        with open(args.glove_embed_path, 'rb') as f:
            glove_embeddings = pickle.load(f)
        print(
            f"Loaded the glove embeddings from {args.glove_embed_path}, total size={len(glove_embeddings)}"
        )

        # We are using 300d glove embeddings
        args.embed_size = 300

        weights_matrix = np.zeros((len(vocab_object), args.embed_size))

        for word, index in vocab_object.word2index.items():
            if word in glove_embeddings:
                weights_matrix[index] = glove_embeddings[word]
            else:
                weights_matrix[index] = np.random.normal(
                    scale=0.6, size=(args.embed_size, ))

        weights_matrix = torch.from_numpy(weights_matrix).float().to(device)

    else:
        weights_matrix = None

    img_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    train_dataset = cocoDataset(args.image_root, args.captions_json,
                                vocab_object, img_transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn)

    encoder = Encoder(args.resnet_size, (3, 224, 224),
                      args.embed_size).to(device)
    decoder = Decoder(args.rnn_type, weights_matrix, len(vocab_object),
                      args.embed_size, args.hidden_size).to(device)

    encoder_learnable = list(encoder.linear.parameters())
    decoder_learnable = list(decoder.rnn.parameters()) + list(
        decoder.linear.parameters())
    if args.glove_embed_path is None:
        decoder_learnable = decoder_learnable + list(
            decoder.embedding.parameters())

    criterion = nn.CrossEntropyLoss()
    params = encoder_learnable + decoder_learnable
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    start_epoch = 0

    if args.ckpt_path is not None:
        model_ckpt = torch.load(args.ckpt_path)
        start_epoch = model_ckpt['epoch'] + 1
        prev_loss = model_ckpt['loss']
        encoder.load_state_dict(model_ckpt['encoder'])
        decoder.load_state_dict(model_ckpt['decoder'])
        optimizer.load_state_dict(model_ckpt['optimizer'])
        print(
            f"Loaded model and optimizer state from {args.ckpt_path}; start epoch at {start_epoch}; prev loss={prev_loss}"
        )

    total_examples = len(train_dataloader)
    for epoch in range(start_epoch, args.num_epochs):
        for i, (images, captions, lengths) in enumerate(train_dataloader):
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True).data

            image_embeddings = encoder(images)
            outputs = decoder(image_embeddings, captions, lengths)

            loss = criterion(outputs, targets)

            decoder.zero_grad()
            encoder.zero_grad()

            loss.backward()
            optimizer.step()

            if i % args.log_interval == 0:
                loss_val = "{:.4f}".format(loss.item())
                perplexity_val = "{:5.4f}".format(np.exp(loss.item()))
                print(
                    f"epoch=[{epoch}/{args.num_epochs}], iteration=[{i}/{total_examples}], loss={loss_val}, perplexity={perplexity_val}"
                )

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss
            },
            os.path.join(args.models_dir,
                         'model-after-epoch-{}.ckpt'.format(epoch)))
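For reference, the target packing above drops the padded positions so the cross-entropy loss only covers real caption tokens. A tiny standalone example with toy values (not from the project):

import torch
from torch.nn.utils.rnn import pack_padded_sequence

captions = torch.tensor([[1, 2, 3, 0],
                         [4, 5, 0, 0]])   # batch of 2, padded with 0
lengths = [3, 2]
targets = pack_padded_sequence(captions, lengths, batch_first=True).data
print(targets.tolist())  # [1, 4, 2, 5, 3] -- padding is gone, ordered by time step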