# --- Target-domain evaluation setup (amazon -> webcam, Office-31 style: 31 classes) ---
tgt_dir = 'webcam'
test_dir = 'test'
cuda = torch.cuda.is_available()
# NOTE(review): `data_dir` is assumed to be defined earlier in the file — confirm.
test_loader = get_dataloader(data_dir, tgt_dir, batch_size=15, train=False)
# lam for confusion
lam = 0.01
# nu for soft
nu = 0.1
# load the pretrained and fine-tuned alex model
encoder = Encoder()
cl_classifier = ClassClassifier(num_classes=31)
dm_classifier = DomainClassifier()
encoder.load_state_dict(torch.load('./checkpoints/a2w/src_encoder_final.pth'))
cl_classifier.load_state_dict(
    torch.load('./checkpoints/a2w/src_classifier_final.pth'))
# NOTE(review): `src_dir`, `tgt_train_dir` and `batch_size` are assumed to be
# defined earlier in the file — confirm.
src_train_loader = get_dataloader(data_dir, src_dir, batch_size, train=True)
tgt_train_loader = get_dataloader(data_dir, tgt_train_dir, batch_size, train=True)
criterion = nn.CrossEntropyLoss()
# criterion_kl = nn.KLDivLoss()
# Move loss and all models to the GPU when one is available.
if cuda:
    criterion = criterion.cuda()
    cl_classifier = cl_classifier.cuda()
    dm_classifier = dm_classifier.cuda()
    encoder = encoder.cuda()
def main():
    """Evaluate a saved SimCLR encoder: freeze it, then fit and test a linear head on CIFAR-10."""
    parser = argparse.ArgumentParser(description='Test SimCLR model')
    parser.add_argument('--EPOCHS', default=1, type=int,
                        help='Number of epochs for training')
    parser.add_argument('--BATCH_SIZE', default=64, type=int, help='Batch size')
    parser.add_argument('--LOG_INT', type=int, default=100,
                        help='how many batches to wait before logging training status')
    parser.add_argument('--SAVED_MODEL', default='./ckpt/model.pth')
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Restore the pretrained encoder and freeze it so only the classifier head trains.
    pretrained_encoder = Encoder()
    pretrained_encoder.load_state_dict(torch.load(args.SAVED_MODEL))
    for weight in pretrained_encoder.parameters():
        weight.requires_grad = False

    test_saved_model = Classifier(pretrained_encoder).to(device)
    criterion = nn.CrossEntropyLoss()
    # Only the final fully-connected layer is optimized.
    optimizer_saved = optim.Adam(test_saved_model.fc.parameters())

    standard_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010]),
    ])

    cifar_train = torchvision.datasets.CIFAR10(root='./data', train=True,
                                               download=False,
                                               transform=standard_transform)
    train_loader = torch.utils.data.DataLoader(cifar_train,
                                               batch_size=args.BATCH_SIZE,
                                               shuffle=True)
    cifar_test = torchvision.datasets.CIFAR10(root='./data', train=False,
                                              download=False,
                                              transform=standard_transform)
    test_loader = torch.utils.data.DataLoader(cifar_test,
                                              batch_size=args.BATCH_SIZE,
                                              shuffle=True)

    for epoch in range(args.EPOCHS):
        print("Performance on the saved model")
        train(args, test_saved_model, device, train_loader, optimizer_saved,
              epoch, criterion)
        test(args, test_saved_model, device, test_loader)
class Dreamer(Agent):
    """Dreamer agent: learns a latent world model and an actor/critic inside it.

    The agent owns its replay buffer (``self.D``), its update routine
    (``update_parameters``) and its acting interface (``infer_state`` /
    ``select_action``).
    """

    def __init__(self, args):
        """Build all model components, optimizers and the replay buffer.

        :param args: namespace/dict of hyper-parameters (sizes, lrs, flags, device)
        """
        super().__init__()
        self.args = args

        # Initialise model parameters randomly
        self.transition_model = TransitionModel(
            args.belief_size, args.state_size, args.action_size,
            args.hidden_size, args.embedding_size,
            args.dense_act).to(device=args.device)
        self.observation_model = ObservationModel(
            args.symbolic, args.observation_size, args.belief_size,
            args.state_size, args.embedding_size,
            activation_function=(args.dense_act if args.symbolic
                                 else args.cnn_act)).to(device=args.device)
        self.reward_model = RewardModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.encoder = Encoder(
            args.symbolic, args.observation_size, args.embedding_size,
            args.cnn_act).to(device=args.device)
        self.actor_model = ActorModel(
            args.action_size, args.belief_size, args.state_size,
            args.hidden_size, activation_function=args.dense_act,
            fix_speed=args.fix_speed,
            throttle_base=args.throttle_base).to(device=args.device)
        # Two value networks (clipped double-value trick) plus frozen targets.
        self.value_model = ValueModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.value_model2 = ValueModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.pcont_model = PCONTModel(
            args.belief_size, args.state_size, args.hidden_size,
            args.dense_act).to(device=args.device)
        self.target_value_model = deepcopy(self.value_model)
        self.target_value_model2 = deepcopy(self.value_model2)
        # Target networks are never updated by gradient descent.
        for p in self.target_value_model.parameters():
            p.requires_grad = False
        for p in self.target_value_model2.parameters():
            p.requires_grad = False

        # Parameters jointly optimized as the "world model".
        self.world_param = list(self.transition_model.parameters()) \
            + list(self.observation_model.parameters()) \
            + list(self.reward_model.parameters()) \
            + list(self.encoder.parameters())
        if args.pcont:
            self.world_param += list(self.pcont_model.parameters())

        # setup optimizers
        self.world_optimizer = optim.Adam(self.world_param, lr=args.world_lr)
        self.actor_optimizer = optim.Adam(self.actor_model.parameters(),
                                          lr=args.actor_lr)
        self.value_optimizer = optim.Adam(
            list(self.value_model.parameters()) +
            list(self.value_model2.parameters()),
            lr=args.value_lr)

        # Allowed deviation in KL divergence (free nats).
        self.free_nats = torch.full((1, ), args.free_nats,
                                    dtype=torch.float32, device=args.device)

        # TODO: change it to the new replay buffer, in buffer.py
        self.D = ExperienceReplay(args.experience_size, args.symbolic,
                                  args.observation_size, args.action_size,
                                  args.bit_depth, args.device)

        if self.args.auto_temp:
            # Learnable temperature (alpha) for the entropy term, as in SAC.
            self.log_temp = torch.zeros(1, requires_grad=True,
                                        device=args.device)
            # heuristic value from SAC paper; steering-only when fix_speed
            self.target_entropy = -np.prod(
                args.action_size if not args.fix_speed
                else self.args.action_size - 1).item()
            self.temp_optimizer = optim.Adam(
                [self.log_temp], lr=args.value_lr)  # use the same value_lr

    def process_im(self, images, image_size=None, rgb=None):
        """Convert a raw camera frame to the model's observation format.

        Resizes to 40x40, converts to grayscale (luma weights), scales to
        [-0.5, 0.5] and adds channel and batch dimensions.

        :param images: HxWxC image in [0, 255] — assumed; confirm caller.
        :param image_size: unused; kept for interface compatibility
        :param rgb: unused; kept for interface compatibility
        :return: tensor of shape [1, 1, 40, 40] in [-0.5, 0.5]
        """
        images = cv2.resize(images, (40, 40))
        images = np.dot(images, [0.299, 0.587, 0.114])  # RGB -> grayscale
        obs = torch.tensor(
            images,
            dtype=torch.float32).div_(255.).sub_(0.5).unsqueeze(
                dim=0)  # shape [1, 40, 40], range:[-0.5,0.5]
        return obs.unsqueeze(dim=0)  # add batch dimension

    def append_buffer(self, new_traj):
        """Append a newly collected trajectory to the replay buffer.

        :param new_traj: iterable of (observation, action, reward, done) tuples
        """
        for state in new_traj:
            observation, action, reward, done = state
            self.D.append(observation, action.cpu(), reward, done)

    def _compute_loss_world(self, state, data):
        """Compute reconstruction, reward, KL and (optional) pcont losses.

        :param state: tuple from the transition model rollout
        :param data: (observations, rewards, nonterminals) sampled batch
        :return: (observation_loss, scaled reward_loss, kl_loss, scaled pcont_loss or 0)
        """
        beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = state
        observations, rewards, nonterminals = data

        observation_loss = F.mse_loss(
            bottle(self.observation_model, (beliefs, posterior_states)),
            observations,
            reduction='none').sum(
                dim=2 if self.args.symbolic else (2, 3, 4)).mean(dim=(0, 1))
        reward_loss = F.mse_loss(
            bottle(self.reward_model, (beliefs, posterior_states)),
            rewards, reduction='none').mean(dim=(0, 1))
        # Transition loss: KL between posterior and prior, clamped from below
        # by free nats so the model is not penalised for tiny divergences.
        kl_loss = torch.max(
            kl_divergence(
                Independent(Normal(posterior_means, posterior_std_devs), 1),
                Independent(Normal(prior_means, prior_std_devs), 1)),
            self.free_nats).mean(dim=(0, 1))
        if self.args.pcont:
            # Discount/continuation prediction trained against nonterminal flags.
            pcont_loss = F.binary_cross_entropy(
                bottle(self.pcont_model, (beliefs, posterior_states)),
                nonterminals)
        return observation_loss, self.args.reward_scale * reward_loss, kl_loss, (
            self.args.pcont_scale * pcont_loss if self.args.pcont else 0)

    def _compute_loss_actor(self, imag_beliefs, imag_states, imag_ac_logps=None):
        """Actor loss: maximise discounted lambda-returns of imagined rollouts."""
        # reward and value prediction of imagined trajectories
        imag_rewards = bottle(self.reward_model, (imag_beliefs, imag_states))
        imag_values = bottle(self.value_model, (imag_beliefs, imag_states))
        imag_values2 = bottle(self.value_model2, (imag_beliefs, imag_states))
        imag_values = torch.min(imag_values, imag_values2)  # double-value clip

        with torch.no_grad():
            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
        pcont = pcont.detach()

        if imag_ac_logps is not None:
            imag_values[1:] -= self.args.temp * imag_ac_logps  # add entropy here

        returns = cal_returns(imag_rewards[:-1], imag_values[:-1],
                              imag_values[-1], pcont[:-1],
                              lambda_=self.args.disclam)

        # Cumulative product of continuation probabilities weights later steps down.
        discount = torch.cumprod(
            torch.cat([torch.ones_like(pcont[:1]), pcont[:-2]], 0), 0)
        discount = discount.detach()
        assert list(discount.size()) == list(returns.size())
        actor_loss = -torch.mean(discount * returns)
        return actor_loss

    def _compute_loss_critic(self, imag_beliefs, imag_states, imag_ac_logps=None):
        """Critic loss: regress both value nets onto target-network lambda-returns."""
        with torch.no_grad():
            # calculate the target with the target networks
            target_imag_values = bottle(self.target_value_model,
                                        (imag_beliefs, imag_states))
            target_imag_values2 = bottle(self.target_value_model2,
                                         (imag_beliefs, imag_states))
            target_imag_values = torch.min(target_imag_values,
                                           target_imag_values2)
            imag_rewards = bottle(self.reward_model,
                                  (imag_beliefs, imag_states))
            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
            if imag_ac_logps is not None:
                target_imag_values[1:] -= self.args.temp * imag_ac_logps

        returns = cal_returns(imag_rewards[:-1], target_imag_values[:-1],
                              target_imag_values[-1], pcont[:-1],
                              lambda_=self.args.disclam)
        target_return = returns.detach()

        value_pred = bottle(self.value_model, (imag_beliefs, imag_states))[:-1]
        value_pred2 = bottle(self.value_model2, (imag_beliefs, imag_states))[:-1]
        value_loss = F.mse_loss(value_pred, target_return,
                                reduction="none").mean(dim=(0, 1))
        value_loss2 = F.mse_loss(value_pred2, target_return,
                                 reduction="none").mean(dim=(0, 1))
        value_loss += value_loss2
        return value_loss

    def _latent_imagination(self, beliefs, posterior_states, with_logprob=False):
        """Roll out imagined trajectories from every posterior state.

        :param beliefs: [chunk, batch, belief_size] deterministic beliefs
        :param posterior_states: [chunk, batch, state_size] stochastic states
        :param with_logprob: also collect action log-probabilities
        :return: (imag_beliefs [horizon+1, chunk*batch, belief_size],
                  imag_states  [horizon+1, chunk*batch, state_size],
                  imag_ac_logps [horizon, chunk*batch] or None)
        """
        chunk_size, batch_size, _ = list(posterior_states.size())
        flatten_size = chunk_size * batch_size
        # Every (chunk, batch) pair becomes an independent imagination start.
        posterior_states = posterior_states.detach().reshape(flatten_size, -1)
        beliefs = beliefs.detach().reshape(flatten_size, -1)

        imag_beliefs, imag_states, imag_ac_logps = [beliefs], [posterior_states], []
        for _ in range(self.args.planning_horizon):
            imag_action, imag_ac_logp = self.actor_model(
                imag_beliefs[-1].detach(),
                imag_states[-1].detach(),
                deterministic=False,
                with_logprob=with_logprob,
            )
            imag_action = imag_action.unsqueeze(dim=0)  # add time dim
            imag_belief, imag_state, _, _ = self.transition_model(
                imag_states[-1], imag_action, imag_beliefs[-1])
            imag_beliefs.append(imag_belief.squeeze(dim=0))
            imag_states.append(imag_state.squeeze(dim=0))
            if with_logprob:
                imag_ac_logps.append(imag_ac_logp.squeeze(dim=0))

        imag_beliefs = torch.stack(imag_beliefs, dim=0).to(
            self.args.device)  # shape [horizon+1, chunk*batch, belief_size]
        imag_states = torch.stack(imag_states, dim=0).to(self.args.device)
        if with_logprob:
            imag_ac_logps = torch.stack(imag_ac_logps, dim=0).to(
                self.args.device)  # shape [horizon, chunk*batch]
        return imag_beliefs, imag_states, imag_ac_logps if with_logprob else None

    def update_parameters(self, gradient_steps):
        """Run ``gradient_steps`` updates of world model, actor and critics.

        :return: list of per-step loss values
                 [obs, reward, kl, pcont, actor, critic]
        """
        loss_info = []  # used to record losses
        for _ in tqdm(range(gradient_steps)):
            # Sample a batch of sequences from the replay buffer.
            observations, actions, rewards, nonterminals = self.D.sample(
                self.args.batch_size, self.args.chunk_size)
            init_belief = torch.zeros(self.args.batch_size,
                                      self.args.belief_size,
                                      device=self.args.device)
            init_state = torch.zeros(self.args.batch_size,
                                     self.args.state_size,
                                     device=self.args.device)
            # Update belief/state using posterior from previous belief/state,
            # previous action and current observation (entire sequence at once).
            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(
                init_state, actions, init_belief,
                bottle(self.encoder, (observations, )), nonterminals)

            # --- world model update ---
            world_model_loss = self._compute_loss_world(
                state=(beliefs, prior_states, prior_means, prior_std_devs,
                       posterior_states, posterior_means, posterior_std_devs),
                data=(observations, rewards, nonterminals))
            observation_loss, reward_loss, kl_loss, pcont_loss = world_model_loss
            self.world_optimizer.zero_grad()
            (observation_loss + reward_loss + kl_loss + pcont_loss).backward()
            nn.utils.clip_grad_norm_(self.world_param,
                                     self.args.grad_clip_norm, norm_type=2)
            self.world_optimizer.step()

            # Freeze world-model and critic params during the actor update to
            # avoid building/storing their gradients (saves memory).
            for p in self.world_param:
                p.requires_grad = False
            for p in self.value_model.parameters():
                p.requires_grad = False
            for p in self.value_model2.parameters():
                p.requires_grad = False  # fixed: was `requires_gard` (typo)

            # latent imagination
            imag_beliefs, imag_states, imag_ac_logps = self._latent_imagination(
                beliefs, posterior_states,
                with_logprob=self.args.with_logprob)

            # --- temperature update (entropy coefficient, SAC-style) ---
            # NOTE(review): assumes with_logprob is enabled whenever auto_temp
            # is, otherwise imag_ac_logps is None here — confirm config coupling.
            if self.args.auto_temp:
                temp_loss = -(self.log_temp *
                              (imag_ac_logps[0] +
                               self.target_entropy).detach()).mean()
                self.temp_optimizer.zero_grad()
                temp_loss.backward()
                self.temp_optimizer.step()
                self.args.temp = self.log_temp.exp()

            # --- actor update ---
            actor_loss = self._compute_loss_actor(
                imag_beliefs, imag_states, imag_ac_logps=imag_ac_logps)
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_model.parameters(),
                                     self.args.grad_clip_norm, norm_type=2)
            self.actor_optimizer.step()

            # Unfreeze for the critic/world updates of the next iteration.
            for p in self.world_param:
                p.requires_grad = True
            for p in self.value_model.parameters():
                p.requires_grad = True
            for p in self.value_model2.parameters():
                p.requires_grad = True

            # --- critic update (on detached imagined trajectories) ---
            imag_beliefs = imag_beliefs.detach()
            imag_states = imag_states.detach()
            critic_loss = self._compute_loss_critic(
                imag_beliefs, imag_states, imag_ac_logps=imag_ac_logps)
            self.value_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.value_model.parameters(),
                                     self.args.grad_clip_norm, norm_type=2)
            nn.utils.clip_grad_norm_(self.value_model2.parameters(),
                                     self.args.grad_clip_norm, norm_type=2)
            self.value_optimizer.step()

            loss_info.append([
                observation_loss.item(),
                reward_loss.item(),
                kl_loss.item(),
                pcont_loss.item() if self.args.pcont else 0,
                actor_loss.item(),
                critic_loss.item()
            ])

        # Finally, sync target value functions once per call (every
        # `gradient_steps` gradient updates).
        with torch.no_grad():
            self.target_value_model.load_state_dict(
                self.value_model.state_dict())
        with torch.no_grad():
            self.target_value_model2.load_state_dict(
                self.value_model2.state_dict())
        return loss_info

    def infer_state(self, observation, action, belief=None, state=None):
        """Infer belief over current state q(s_t|o≤t,a<t) from the history.

        :param observation: observation tensor already on device
        :param action: action tensor of shape [act_dim]; time dim added here
        :param belief: previous belief (or initial zeros)
        :param state: previous posterior state (or initial zeros)
        :return: (belief, posterior_state) at time t with time dim removed
        """
        # Action and observation need an extra time dimension.
        belief, _, _, _, posterior_state, _, _ = self.transition_model(
            state, action.unsqueeze(dim=0), belief,
            self.encoder(observation).unsqueeze(dim=0))
        # Remove time dimension from belief/state.
        belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)
        return belief, posterior_state

    def select_action(self, state, deterministic=False):
        """Choose an action from (belief, posterior_state).

        :param state: tuple produced by ``infer_state``
        :param deterministic: if True, use the actor's mean action
        :return: action tensor on device, shape [batch, act_size]
        """
        belief, posterior_state = state
        action, _ = self.actor_model(belief, posterior_state,
                                     deterministic=deterministic,
                                     with_logprob=False)
        if not deterministic and not self.args.with_logprob:
            # Add exploration noise, then clip back into the valid ranges.
            action = Normal(action, self.args.expl_amount).rsample()
            # clip the angle
            action[:, 0].clamp_(min=self.args.angle_min,
                                max=self.args.angle_max)
            # clip the throttle
            if self.args.fix_speed:
                action[:, 1] = self.args.throttle_base
            else:
                action[:, 1].clamp_(min=self.args.throttle_min,
                                    max=self.args.throttle_max)
        return action  # tensor on device; caller converts to numpy if needed

    def import_parameters(self, params):
        """Load the parameters used for local rollout (encoder/policy/transition)."""
        self.encoder.load_state_dict(params["encoder"])
        self.actor_model.load_state_dict(params["policy"])
        self.transition_model.load_state_dict(params["transition"])

    def export_parameters(self):
        """Return CPU state dicts of the models used for local rollout."""
        params = {
            "encoder": self.encoder.cpu().state_dict(),
            "policy": self.actor_model.cpu().state_dict(),
            "transition": self.transition_model.cpu().state_dict()
        }
        # Move the models back to the training device after the CPU export.
        self.encoder.to(self.args.device)
        self.actor_model.to(self.args.device)
        self.transition_model.to(self.args.device)
        return params
type=float, nargs='+', help=' x, y\ values from the disruption to decodding') return parser.parse_args() args = manage() cuda = args.cuda and torch.cuda.is_available() encoder = Encoder() decoder = Decoder() if cuda: encoder.load_state_dict(torch.load(args.encoder_path)) encoder.cuda() else: encoder.load_state_dict(torch.load(args.encoder_path, map_location='cpu')) if cuda: decoder.load_state_dict(torch.load(args.decoder_path)) decoder.cuda() else: decoder.load_state_dict(torch.load(args.decoder_path, map_location='cpu')) if args.image_path: image = cv2.imread(args.image_path) image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) code = encoder(transforms.ToTensor()(transforms.ToPILImage()(image)).view( -1, 784))
# Record the episode index for this collection step.
metrics['episodes'].append(s)

# Initialise model parameters randomly
transition_model = TransitionModel(args.belief_size, args.state_size, env.action_size,
                                   args.hidden_size, args.embedding_size,
                                   args.activation_function).to(device=args.device)
observation_model = ObservationModel(args.symbolic_env, env.observation_size,
                                     args.belief_size, args.state_size,
                                     args.embedding_size,
                                     args.activation_function).to(device=args.device)
reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size,
                           args.activation_function).to(device=args.device)
encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size,
                  args.activation_function).to(device=args.device)
# One optimiser over all world-model parameters; lr starts at 0 when a
# learning-rate schedule will ramp it up later.
param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters())
optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon)
# Optionally resume models and optimiser from a numbered checkpoint.
if args.load_checkpoint > 0:
    model_dicts = torch.load(os.path.join(results_dir, 'models_%d.pth' % args.load_checkpoint))
    transition_model.load_state_dict(model_dicts['transition_model'])
    observation_model.load_state_dict(model_dicts['observation_model'])
    reward_model.load_state_dict(model_dicts['reward_model'])
    encoder.load_state_dict(model_dicts['encoder'])
    optimiser.load_state_dict(model_dicts['optimiser'])
planner = MPCPlanner(env.action_size, args.planning_horizon,
                     args.optimisation_iters, args.candidates,
                     args.top_candidates, transition_model, reward_model)
global_prior = Normal(torch.zeros(args.batch_size, args.state_size, device=args.device),
                      torch.ones(args.batch_size, args.state_size, device=args.device))  # Global prior N(0, I)
free_nats = torch.full((1, ), args.free_nats, device=args.device)  # Allowed deviation in KL divergence


def update_belief_and_act(args, env, planner, transition_model, encoder,
                          belief, posterior_state, action, observation, test):
    """One step of belief update + planning + environment interaction.

    Returns (belief, posterior_state, action, next_observation, reward, done).
    """
    # Infer belief over current state q(s_t|o≤t,a<t) from the history
    belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0))  # Action and observation need extra time dimension
    belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
    action = planner(belief, posterior_state)  # Get action from planner(q(s_t|o≤t,a<t), p)
    if not test:
        action = action + args.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
    next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # Perform environment step (action repeats handled internally)
    return belief, posterior_state, action, next_observation, reward, done
# model # Pretrained Model alexnet = torchvision.models.alexnet(pretrained=True) pretrained_dict = alexnet.state_dict() # Train source data # Model parameters src_encoder = Encoder() src_classifier = ClassClassifier(num_classes=31) src_encoder_dict = src_encoder.state_dict() # Load pretrained model # filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in src_encoder_dict} # overwrite entries in the existing state dict src_encoder_dict.update(pretrained_dict) # load the new state dict src_encoder.load_state_dict(src_encoder_dict) optimizer = optim.SGD( list(src_encoder.parameters()) + list(src_classifier.parameters()), lr=lr, momentum=momentum, weight_decay=weight_decay) criterion = nn.CrossEntropyLoss() if cuda: src_encoder = src_encoder.cuda() src_classifier = src_classifier.cuda() criterion = criterion.cuda() src_encoder.train() src_classifier.train()
def main():
    """Train (or resume) the encoder-decoder image captioning model and
    checkpoint on BLEU-4 improvement."""
    print('Training parameters Initialized')
    # Hyper-parameters and bookkeeping for the run.
    training_parameters = TrainingParameters(
        start_epoch = 0,
        epochs = 120,  # number of epochs to train for
        epochs_since_improvement = 0,  # Epochs since improvement in BLEU score
        batch_size = 32,
        workers = 1,  # for data-loading; right now, only 1 works with h5py
        fine_tune_encoder = True,  # fine-tune encoder
        encoder_lr = 1e-4,  # learning rate for encoder, if fine-tuning is used
        decoder_lr = 4e-4,  # learning rate for decoder
        grad_clip = 5.0,  # clip gradients at an absolute value of
        alpha_c = 1.0,  # regularization parameter for 'doubly stochastic attention'
        best_bleu4 = 0.0,  # BLEU-4 score right now
        print_freq = 100,  # print training/validation stats every __ batches
        checkpoint = './Result/BEST_checkpoint_flickr8k_5_captions_per_image_5_minimum_word_frequency.pth.tar'  # path to checkpoint, None if none
        # checkpoint = None
    )

    print('Loading Word-Map')
    word_map_file = os.path.join(data_folder,'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    print('Creating Model')
    if training_parameters.checkpoint is None:
        # Fresh run: build encoder/decoder and their optimizers from scratch.
        encoder = Encoder()
        encoder.fine_tune(training_parameters.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, encoder.parameters()),
                                             lr=training_parameters.encoder_lr) if training_parameters.fine_tune_encoder else None
        decoder = Decoder(attention_dimension = attention_dimension,
                          embedding_dimension = embedding_dimension,
                          hidden_dimension = hidden_dimension,
                          vocab_size = len(word_map),
                          device = device,
                          dropout = dropout)
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, decoder.parameters()),
                                             lr=training_parameters.decoder_lr)
    else:
        # Resume: restore models, optimizers and progress counters from the checkpoint.
        checkpoint = torch.load(training_parameters.checkpoint)
        training_parameters.start_epoch = checkpoint['epoch'] + 1
        training_parameters.epochs_since_improvement = checkpoint['epochs_since_improvement']
        training_parameters.best_bleu4 = checkpoint['bleu4']
        encoder = Encoder()
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        encoder_optimizer = checkpoint['encoder_optimizer']
        decoder = Decoder(attention_dimension = attention_dimension,
                          embedding_dimension = embedding_dimension,
                          hidden_dimension = hidden_dimension,
                          vocab_size = len(word_map),
                          device = device,
                          dropout = dropout)
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        decoder_optimizer = checkpoint['decoder_optimizer']
        if training_parameters.fine_tune_encoder is True and encoder_optimizer is None:
            # Checkpoint was saved without encoder fine-tuning; enable it now.
            encoder.fine_tune(training_parameters.fine_tune_encoder)
            encoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, encoder.parameters()),
                                                 lr=training_parameters.encoder_lr)

    encoder.to(device)
    decoder.to(device)
    criterion = nn.CrossEntropyLoss().to(device)
    # ImageNet normalisation statistics.
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    print('Creating Data Loaders')
    train_dataloader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
        batch_size=training_parameters.batch_size, shuffle=True)
    validation_dataloader = torch.utils.data.DataLoader(
        CaptionDataset(data_folder, data_name, 'VALID', transform=transforms.Compose([normalize])),
        batch_size=training_parameters.batch_size, shuffle=True, pin_memory=True)

    for epoch in range(training_parameters.start_epoch, training_parameters.epochs):
        # Stop early after 20 epochs without BLEU-4 improvement.
        if training_parameters.epochs_since_improvement == 20:
            break
        # Decay learning rates every 8 stagnant epochs.
        if training_parameters.epochs_since_improvement > 0 and training_parameters.epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if training_parameters.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)
        train(train_loader = train_dataloader,
              encoder = encoder,
              decoder = decoder,
              criterion = criterion,
              encoder_optimizer = encoder_optimizer,
              decoder_optimizer = decoder_optimizer,
              epoch = epoch,
              device = device,
              training_parameters = training_parameters)
        recent_bleu4_score = validate(validation_loader = validation_dataloader,
                                      encoder = encoder,
                                      decoder = decoder,
                                      criterion = criterion,
                                      word_map = word_map,
                                      device = device,
                                      training_parameters = training_parameters)
        # Track improvement and checkpoint (best model saved separately).
        is_best_score = recent_bleu4_score > training_parameters.best_bleu4
        training_parameters.best_bleu4 = max(recent_bleu4_score, training_parameters.best_bleu4)
        if not is_best_score:
            training_parameters.epochs_since_improvement += 1
            print('\nEpochs since last improvement : %d\n' % (training_parameters.epochs_since_improvement))
        else:
            training_parameters.epochs_since_improvement = 0
        save_checkpoint(data_name, epoch, training_parameters.epochs_since_improvement,
                        encoder, decoder, encoder_optimizer, decoder_optimizer,
                        recent_bleu4_score, is_best_score)
# --- Inference-time setup for the image captioning model ---
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'
word_map_file = 'datasets/caption_data/WORDMAP_' + data_name + '.json'
checkpoint = 'logs/tmp/BEST_MODEL.pth.tar'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(word_map_file, 'r') as f:
    word_map = json.load(f)
rev_word_map = {v: k for k, v in word_map.items()}  # id -> word, for decoding predictions
vocab_size = len(word_map)

# Load the model
encoder = Encoder()
decoder = DecoderWithAttention(512, 512, 512, vocab_size)
checkpoint = torch.load(checkpoint)
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])
encoder.to(device)
decoder.to(device)
# Evaluation mode: disables dropout/batch-norm updates for inference.
encoder.eval()
decoder.eval()

# ImageNet preprocessing pipeline applied to every input image.
preprocess = T.Compose([
    T.Resize(size=(256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])


def get_image_caption(ori_img):
    # NOTE(review): function body continues beyond this chunk.
"WARNING: You have a CUDA device, so you should probably run with --cuda" ) ###### Definition of variables ###### # Networks encoder = Encoder(input_nc=input_nc) decoder_A2B = Decoder(output_nc=output_nc) decoder_B2A = Decoder(output_nc=output_nc) if activate_cuda: encoder.cuda() decoder_A2B.cuda() decoder_B2A.cuda() # Load state dicts encoder.load_state_dict(torch.load(encoder_path)) decoder_A2B.load_state_dict(torch.load(decoder_A2B_path)) decoder_B2A.load_state_dict(torch.load(decoder_B2A_path)) # Set model's test mode encoder.eval() decoder_A2B.eval() decoder_B2A.eval() # Inputs & targets memory allocation Tensor = torch.cuda.FloatTensor if activate_cuda else torch.Tensor input_A = Tensor(batch_size, input_nc, image_size, image_size) input_B = Tensor(batch_size, output_nc, image_size, image_size) # Dataset loader transforms_ = [
def _encode_split(encoder, data_loader, device, representation_store, arguments,
                  dataset_length, split_label):
    """Streams one data split through the encoder and appends the features to an HDF5 store.

    :param encoder: encoder model already in eval mode on *device*.
    :param data_loader: DataLoader yielding (images, labels) batches.
    :param device: PyTorch device object.
    :param representation_store: HDF5Handler the features are appended to.
    :param arguments: Dictionary containing arguments.
    :param dataset_length: Number of examples in the split (for progress output).
    :param split_label: "Training" or "Testing" — used in the progress message.
    """
    num_batches = 0
    for images, _ in data_loader:
        # Loads the image batch into memory.
        images = images.to(device)

        # Gets the representations of the image batch from the encoder.
        representations = encoder.forward_features(images)

        # Moves the representations to the CPU.
        representations = representations.cpu().data.numpy()

        # Adds the batch representations to the HDF5 file.
        representation_store.append(representations)

        # Prints information about the representation extraction process.
        num_batches += 1
        if num_batches % arguments["log_intervals"] == 0:
            print(
                f"{split_label} Batches: {num_batches}/{dataset_length // arguments['batch_size']}"
            )


def make_representations(arguments, device):
    """
    Creates representations for all data.
    :param arguments: Dictionary containing arguments.
    :param device: PyTorch device object.
    """
    # Loads training and testing data.
    train_data = Dataset(arguments, "train")
    test_data = Dataset(arguments, "test")

    # Creates the data loaders for the training and testing data.
    training_data_loader = DataLoader(train_data,
                                      batch_size=arguments["batch_size"],
                                      shuffle=False,
                                      num_workers=arguments["data_workers"],
                                      pin_memory=False,
                                      drop_last=False)
    testing_data_loader = DataLoader(test_data,
                                     batch_size=arguments["batch_size"],
                                     shuffle=False,
                                     num_workers=arguments["data_workers"],
                                     pin_memory=False,
                                     drop_last=False)
    log(arguments, "Loaded Datasets")

    # Initialises the encoder.
    encoder = Encoder(0, arguments["image_size"],
                      arguments["pretrained"] == "imagenet")

    # Loads weights from pretrained Contrastive Predictive Coding model.
    if arguments["pretrained"].lower() == "cpc":
        encoder_path = os.path.join(
            arguments["model_dir"],
            f"{arguments['experiment']}_encoder_best.pt")
        encoder.load_state_dict(torch.load(encoder_path, map_location=device),
                                strict=False)

    # Sets the model to evaluation mode and moves it to the selected device.
    encoder.eval()
    encoder.to(device)

    # If 16 bit precision is being used change the model precision.
    if arguments["precision"] == 16:
        encoder = amp.initialize(encoder, opt_level="O2", verbosity=False)
    # Checks if precision level is supported and if not defaults to 32.
    elif arguments["precision"] != 32:
        log(
            arguments,
            "Only 16 and 32 bit precision supported. Defaulting to 32 bit precision."
        )

    log(arguments, "Models Initialised")

    # Creates a folder if one does not exist.
    os.makedirs(os.path.dirname(arguments["representation_dir"]), exist_ok=True)

    # Creates the HDF5 files used to store the training and testing data representations.
    train_representations = HDF5Handler(
        os.path.join(arguments["representation_dir"],
                     f"{arguments['experiment']}_train.h5"), 'x',
        (encoder.encoder_size, ))
    test_representations = HDF5Handler(
        os.path.join(arguments["representation_dir"],
                     f"{arguments['experiment']}_test.h5"), 'x',
        (encoder.encoder_size, ))
    log(arguments, "HDF5 Representation Files Created.")

    # Starts a timer.
    start_time = time.time()

    # Performs representation generation with no gradients; the training and
    # testing loops were duplicated verbatim, so both now share one helper.
    with torch.no_grad():
        _encode_split(encoder, training_data_loader, device,
                      train_representations, arguments, len(train_data),
                      "Training")
        _encode_split(encoder, testing_data_loader, device,
                      test_representations, arguments, len(test_data),
                      "Testing")

    print(
        f"Representations from {arguments['experiment']} encoder created in {int(time.time() - start_time)}s"
    )
# Optimizers for the autoencoder; learning rates come from the config "knobs".
opt_encoder = torch.optim.Adam(encoder.parameters(), lr=knobs["lr_encoder"])
opt_decoder = torch.optim.Adam(decoder.parameters(), lr=knobs["lr_decoder"])

# Metric collectors for the training loop.
collector_reconstruction_loss = Collector()
collector_wasserstein_penalty = Collector()
collector_fooling_term = Collector()
collector_codes_min = Collector()
collector_codes_max = Collector()

if knobs["resume"]:
    # Resume: reuse the most recently modified log/checkpoint dirs and restore state.
    writer = SummaryWriter(log_dir_last_modified)
    checkpoint_dir = checkpoints_dir_last_modified
    checkpoint = torch.load(checkpoint_dir)
    starting_epoch = checkpoint["epoch"]
    iteration = checkpoint["iteration"]
    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])
    opt_encoder.load_state_dict(checkpoint["opt_encoder_state_dict"])
    opt_decoder.load_state_dict(checkpoint["opt_decoder_state_dict"])
else:
    # Fresh run: new timestamped log/checkpoint directories.
    writer = SummaryWriter(log_dir_local_time)
    checkpoint_dir = checkpoints_dir_local_time
    starting_epoch = 1
    iteration = 0

encoder.train()
decoder.train()

for epoch in range(starting_epoch, knobs["num_epochs"] + 1):
    for batch in loader:
        iteration += 1
        # NOTE(review): loop body continues beyond this chunk.
def train_CIFAR10(opt):
    """Train a VQ-VAE on CIFAR-10.

    Args:
        opt: parsed command-line options; only ``opt.config`` (path to the
            config file consumed by ``get_config``) is read here.

    Side effects: creates a timestamped save directory, snapshots the source
    files and config into it, downloads CIFAR-10 if needed, and writes one
    checkpoint plus a pair of reconstruction images per epoch.
    """
    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt

    params = get_config(opt.config)

    # Timestamped run directory; snapshot code + config for reproducibility.
    save_path = os.path.join(
        params['save_path'],
        datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(save_path, exist_ok=True)
    shutil.copy('models.py', os.path.join(save_path, 'models.py'))
    shutil.copy('train.py', os.path.join(save_path, 'train.py'))
    shutil.copy(opt.config, os.path.join(save_path,
                                         os.path.basename(opt.config)))

    cuda = torch.cuda.is_available()
    gpu_ids = list(range(torch.cuda.device_count()))
    TensorType = torch.cuda.FloatTensor if cuda else torch.Tensor

    data_path = os.path.join(params['data_root'], 'cifar10')
    os.makedirs(data_path, exist_ok=True)
    train_dataset = datasets.CIFAR10(root=data_path,
                                     train=True,
                                     download=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize(
                                             (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                     ]))
    val_dataset = datasets.CIFAR10(root=data_path,
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize(
                                           (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
                                   ]))
    # Scale the per-device batch size by the number of GPUs DataParallel uses.
    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'] * len(gpu_ids),
                              shuffle=True,
                              num_workers=params['num_workers'],
                              pin_memory=cuda)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            num_workers=params['num_workers'],
                            pin_memory=cuda)
    # Normalising constant for the reconstruction error.
    # NOTE(review): `.train_data` is the pre-0.2.2 torchvision attribute;
    # newer releases expose it as `.data` — confirm the pinned version.
    data_variance = np.var(train_dataset.train_data / 255.0)

    encoder = Encoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    decoder = Decoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    vq = VectorQuantizer(params['k'], params['d'], params['beta'],
                         params['decay'], TensorType)

    # Optionally resume weights (fix: identity comparison with None).
    if params['checkpoint'] is not None:
        checkpoint = torch.load(params['checkpoint'])
        params['start_epoch'] = checkpoint['epoch']
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        vq.load_state_dict(checkpoint['vq'])

    model = VQVAE(encoder, decoder, vq)
    if cuda:
        model = nn.DataParallel(model.cuda(), device_ids=gpu_ids)

    parameters = list(model.parameters())
    # Fix: renamed from `opt`, which shadowed the function argument.
    optimizer = torch.optim.Adam([p for p in parameters if p.requires_grad],
                                 lr=params['lr'])

    for epoch in range(params['start_epoch'], params['num_epochs']):
        train_bar = tqdm(train_loader)
        for data, _ in train_bar:
            if cuda:
                data = data.cuda()
            optimizer.zero_grad()
            vq_loss, data_recon, _ = model(data)
            recon_error = torch.mean((data_recon - data)**2) / data_variance
            loss = recon_error + vq_loss.mean()
            loss.backward()
            optimizer.step()
            train_bar.set_description('Epoch {}: loss {:.4f}'.format(
                epoch + 1, loss.mean().item()))

        # Qualitative check: reconstruct one validation image per epoch.
        model.eval()
        with torch.no_grad():  # fix: no autograd graph needed at eval time
            data_val = next(iter(val_loader))
            data_val, _ = data_val
            if cuda:
                data_val = data_val.cuda()
            _, data_recon_val, _ = model(data_val)
        # Images were normalised to [-0.5, 0.5]; shift back for saving.
        plt.imsave(os.path.join(save_path, 'latest_val_recon.png'),
                   (make_grid(data_recon_val.cpu().data) +
                    0.5).numpy().transpose(1, 2, 0))
        plt.imsave(os.path.join(save_path, 'latest_val_orig.png'),
                   (make_grid(data_val.cpu().data) + 0.5).numpy().transpose(
                       1, 2, 0))
        model.train()

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'vq': vq.state_dict(),
            }, os.path.join(save_path, '{}_checkpoint.pth'.format(epoch)))
# (continuation of a call that opens before this excerpt — presumably a
#  corpus loader returning English/Vietnamese sentence lists; TODO confirm)
    'data/train.en.txt', 'data/train.vi.txt', args.max_length)

# Build one LanguageModel vocabulary per language from the raw lines.
English = utils.LanguageModel("English")
[English.add_line(i) for i in English_sentences]
Vietnamese = utils.LanguageModel("Vietnamese")
[Vietnamese.add_line(i) for i in Vietnamese_sentences]
# Convert every sentence into its token-id sequence.
English_tokens = [English.tokens_from_line(i) for i in English_sentences]
Vietnamese_tokens = [
    Vietnamese.tokens_from_line(i) for i in Vietnamese_sentences
]
encoder = Encoder(args, input_size=English.num_words)
decoder = Decoder(args, output_size=Vietnamese.num_words)
args.output_size = Vietnamese.num_words
if 'train' in args.mode:
    if args.checkpoint is True:
        # Warm-start from previously saved weights.
        encoder.load_state_dict(
            torch.load(os.path.join('model', 'encoder.pkl')))
        decoder.load_state_dict(
            torch.load(os.path.join('model', 'decoder.pkl')))
        # NOTE(review): eval() immediately before training looks wrong —
        # train_handler presumably needs train() mode; confirm intent.
        encoder.eval()
        decoder.eval()
    utils.train_handler(args, encoder, decoder, English_tokens,
                        Vietnamese_tokens, English, Vietnamese)
    # Persist the trained weights.
    torch.save(encoder.state_dict(), os.path.join('model', 'encoder.pkl'))
    torch.save(decoder.state_dict(), os.path.join('model', 'decoder.pkl'))
elif 'translate' in args.mode:
    # Inference: load trained weights and freeze in eval mode.
    encoder.load_state_dict(torch.load(os.path.join('model', 'encoder.pkl')))
    decoder.load_state_dict(torch.load(os.path.join('model', 'decoder.pkl')))
    encoder.eval()
    decoder.eval()
    # Interactive translation loop — body continues beyond this excerpt.
    while True:
def train_model(cfg):
    """Train the VQ-VAE speech model described by the Hydra config ``cfg``.

    Builds encoder/decoder, wraps them with NVIDIA apex AMP (opt level O1),
    optionally resumes from ``cfg.resume``, then optimises the sum of the
    decoder's autoregressive cross-entropy and the encoder's VQ loss,
    logging epoch averages to TensorBoard and checkpointing every
    ``cfg.training.checkpoint_interval`` global steps.
    """
    tensorboard_path = Path(
        utils.to_absolute_path("tensorboard")) / cfg.checkpoint_dir
    checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
    writer = SummaryWriter(tensorboard_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    decoder = Decoder(**cfg.model.decoder)
    encoder.to(device)
    decoder.to(device)

    optimizer = optim.Adam(chain(encoder.parameters(), decoder.parameters()),
                           lr=cfg.training.optimizer.lr)
    # Mixed precision: amp rewrites both modules and the optimizer in place.
    [encoder, decoder], optimizer = amp.initialize([encoder, decoder],
                                                   optimizer,
                                                   opt_level="O1")
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=cfg.training.scheduler.milestones,
        gamma=cfg.training.scheduler.gamma)

    if cfg.resume:
        print("Resume checkpoint from: {}:".format(cfg.resume))
        resume_path = utils.to_absolute_path(cfg.resume)
        # map_location keeps the load on CPU regardless of saving device.
        checkpoint = torch.load(resume_path,
                                map_location=lambda storage, loc: storage)
        encoder.load_state_dict(checkpoint["encoder"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    root_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    dataset = SpeechDataset(root=root_path,
                            hop_length=cfg.preprocessing.hop_length,
                            sr=cfg.preprocessing.sr,
                            sample_frames=cfg.training.sample_frames)
    dataloader = DataLoader(dataset,
                            batch_size=cfg.training.batch_size,
                            shuffle=True,
                            num_workers=cfg.training.n_workers,
                            pin_memory=True,
                            drop_last=True)

    # Derive the epoch range from the step budget and any resumed progress.
    n_epochs = cfg.training.n_steps // len(dataloader) + 1
    start_epoch = global_step // len(dataloader) + 1

    for epoch in range(start_epoch, n_epochs + 1):
        average_recon_loss = average_vq_loss = average_perplexity = 0
        for i, (audio, mels, speakers) in enumerate(tqdm(dataloader), 1):
            audio, mels, speakers = audio.to(device), mels.to(
                device), speakers.to(device)

            optimizer.zero_grad()

            z, vq_loss, perplexity = encoder(mels)
            # Teacher forcing: condition on audio[:, :-1], predict audio[:, 1:].
            output = decoder(audio[:, :-1], z, speakers)
            recon_loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
            loss = recon_loss + vq_loss

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
            optimizer.step()
            scheduler.step()

            # Incremental running means over the epoch.
            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

            global_step += 1

            if global_step % cfg.training.checkpoint_interval == 0:
                save_checkpoint(encoder, decoder, optimizer, amp, scheduler,
                                global_step, checkpoint_dir)

        writer.add_scalar("recon_loss/train", average_recon_loss, global_step)
        writer.add_scalar("vq_loss/train", average_vq_loss, global_step)
        writer.add_scalar("average_perplexity", average_perplexity,
                          global_step)

        # Fix: corrected "perpexlity" typo in the progress message.
        print("epoch:{}, recon loss:{:.2E}, vq loss:{:.2E}, perplexity:{:.3f}".
              format(epoch, average_recon_loss, average_vq_loss,
                     average_perplexity))
def recover_models(device,
                   model="supervised",
                   m=256,
                   n=4,
                   chann_type="AWGN",
                   verbose=False):
    """ Function to try to recover an already saved system to a channel
    Args:
        device (string): Current device that we are working in
        model (string): Model that wish to be recovered. Options: supervised or alternated
        chann_type (string): Channel type. Currently only AWGN available
        n (int): Length of the encoded messages
        m (int): Total number of messages that can be encoded
    Returns:
        encoder/tx (Object): Recovered Tx/Encoder model
        decoder/rx (Object): Recovered Rx/Decoder model
    Raises:
        NameError: if any weight file is missing or fails to load; the
            original exception is chained as the cause.
    """

    def _load(module, filename):
        # Shared restore path: load weights, report, move to device, eval().
        module.load_state_dict(torch.load(filename))
        if verbose:
            print('Model loaded from %s.' % filename)
        module.to(device)
        module.eval()
        return module

    try:
        if model == "supervised":
            enc_filename = "%s/%s_%d_%d_encoder.pth" % (MODELS_FOLDER,
                                                        chann_type, m, n)
            dec_filename = "%s/%s_%d_%d_decoder.pth" % (MODELS_FOLDER,
                                                        chann_type, m, n)
            encoder = _load(Encoder(m=m, n=n), enc_filename)
            decoder = _load(Decoder(m=m, n=n), dec_filename)
            return encoder, decoder
        else:
            tx_filename = "%s/%s_%d_%d_tx.pth" % (MODELS_FOLDER, chann_type,
                                                  m, n)
            rx_filename = "%s/%s_%d_%d_rx.pth" % (MODELS_FOLDER, chann_type,
                                                  m, n)
            tx = _load(Transmitter(m=m, n=n), tx_filename)
            rx = _load(Receiver(m=m, n=n), rx_filename)
            return tx, rx
    except Exception as err:
        # Fix: the bare `except:` hid the real failure (and even caught
        # KeyboardInterrupt). Keep raising NameError so existing callers'
        # handlers still match, but chain the underlying cause.
        raise NameError("Something went wrong loading file for system (%s)" %
                        (chann_type)) from err
def main(index, args):
    """Per-XLA-core entry point: trains an encoder/generator/discriminator
    GAN (autoencoder-regularised) on a TPU device and checkpoints by FID.
    """
    device = xm.xla_device()
    gen_net = Generator(args).to(device)
    dis_net = Discriminator(args).to(device)
    enc_net = Encoder(args).to(device)

    def weights_init(m):
        # Initialise conv layers per args.init_type and batch-norm layers
        # with the DCGAN-style N(1, 0.02) weights / zero bias.
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                # NOTE(review): nn.init.xavier_uniform is the deprecated
                # (non-underscore) API — xavier_uniform_ is the modern form.
                nn.init.xavier_uniform(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown inital type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    enc_net.apply(weights_init)

    # Four optimisers: AE reconstruction and AE regularisation both update
    # encoder+generator jointly; discriminator and generator have their own.
    ae_recon_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_recon_lr, (args.beta1, args.beta2))
    ae_reg_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_reg_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(dis_net.parameters(), args.d_lr,
                                     (args.beta1, args.beta2))
    gen_optimizer = torch.optim.Adam(gen_net.parameters(), args.g_lr,
                                     (args.beta1, args.beta2))

    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    valid_loader = dataset.valid
    # ParallelLoader feeds the XLA device asynchronously.
    para_loader = pl.ParallelLoader(train_loader, [device])

    # Pre-computed FID statistics for CIFAR-10 test set (downloaded once).
    fid_stat = str(pathlib.Path(
        __file__).parent.absolute()) + '/fid_stat/fid_stat_cifar10_test.npz'
    if not os.path.exists(fid_stat):
        download_stat_cifar10_test()

    is_best = True
    # Convert the iteration budget into whole epochs.
    args.num_epochs = np.ceil(args.num_iter / len(train_loader))
    # Linear LR decay over the second half of training for every optimiser.
    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    ae_recon_scheduler = LinearLrDecay(ae_recon_optimizer, args.ae_recon_lr,
                                       0, args.num_iter / 2, args.num_iter)
    ae_reg_scheduler = LinearLrDecay(ae_reg_optimizer, args.ae_reg_lr, 0,
                                     args.num_iter / 2, args.num_iter)

    # initial
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        # Resume everything (nets, optimisers, best FID, log paths) from the
        # checkpoint under <load_path>/Model/checkpoint.pth.
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        enc_net.load_state_dict(checkpoint['enc_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ae_recon_optimizer.load_state_dict(checkpoint['ae_recon_optimizer'])
        ae_reg_optimizer.load_state_dict(checkpoint['ae_reg_optimizer'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        logs_dir = str(pathlib.Path(__file__).parent.parent) + '/logs'
        args.path_helper = set_log_dir(logs_dir, args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.num_epochs)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler, ae_recon_scheduler,
                         ae_reg_scheduler)
        train(device, args, gen_net, dis_net, enc_net, gen_optimizer,
              dis_optimizer, ae_recon_optimizer, ae_reg_optimizer,
              para_loader, epoch, writer_dict, lr_schedulers)

        # Validate every val_freq epochs (and on the final epoch); track the
        # best FID so the checkpoint can be flagged as best.
        if epoch and epoch % args.val_freq == 0 or epoch == args.num_epochs - 1:
            fid_score = validate(args, fid_stat, gen_net, writer_dict,
                                 valid_loader)
            logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'enc_state_dict': enc_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ae_recon_optimizer': ae_recon_optimizer.state_dict(),
                'ae_reg_optimizer': ae_reg_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
import torchvision.transforms as transforms
import tqdm
from torch.utils.data import DataLoader

from config import Config
from models import Encoder, SiameseNetwork

# Training-side setup: build the encoder + siamese head, optionally resume
# from the epoch-25 weights, and put both modules in train mode.
config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
encoder = Encoder(config)
siamese_network = SiameseNetwork(config)
if config.load_model:
    encoder.load_state_dict(
        torch.load(config.saved_models_folder + '/encoder_epoch25.pth'))
    siamese_network.load_state_dict(
        torch.load(config.saved_models_folder +
                   '/siamese_network_epoch25.pth'))
encoder.to(device)
encoder.train()
siamese_network.to(device)
siamese_network.train()

# Jointly optimise both modules with a single Adam optimiser.
params = list(encoder.parameters()) + list(siamese_network.parameters())
optimizer = torch.optim.Adam(params, lr=config.lr, betas=(0.9, 0.999))

# Input pipeline — the transform list continues beyond this excerpt.
transform = transforms.Compose([
# Rebuild the pretrained captioning model (attention decoder + CNN encoder)
# from saved state dicts and freeze both in eval mode.
with open(word_map_file, 'r') as j:
    word_map = json.load(j)

decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(word_map),
                               dropout=dropout)
decoder.load_state_dict(
    torch.load('/scratch/scratch2/adsue/pretrained/decoder_dict.pkl'))
decoder = decoder.to(device)
decoder.eval()

encoder = Encoder()
encoder.load_state_dict(
    torch.load('/scratch/scratch2/adsue/pretrained/encoder_dict.pkl'))
encoder = encoder.to(device)
encoder.eval()

##########################################################################################################################

# Image preprocessing: upscale slightly (76/64 ratio) then random-crop back
# to the working resolution.
imsize = 256
image_transform = transforms.Compose(
    # Fix: transforms.Scale is a long-deprecated (now removed) alias of
    # transforms.Resize with identical behaviour.
    [transforms.Resize(int(imsize * 76 / 64)),
     transforms.RandomCrop(imsize)])
norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
folder = '/scratch/scratch2/adsue/caption_dataset/val2014/'
from config import Config
from models import Encoder, SiameseNetwork

batch_size = 8
threshold = 0.9

config = Config()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _restore(module, weights_file):
    """Load saved weights into ``module``, move it to ``device`` and freeze
    it in eval mode; returns the module for convenient assignment."""
    module.load_state_dict(
        torch.load(config.saved_models_folder + weights_file))
    module.to(device)
    module.eval()
    return module


# Restore the trained encoder and siamese head for inference.
encoder = _restore(Encoder(config), '/encoder_epoch500_loss0.0009.pth')
siamese_network = _restore(SiameseNetwork(config),
                           '/siamese_network_epoch500_loss0.0009.pth')

# Single-channel grayscale tensors are what the encoder consumes.
transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1),
     transforms.ToTensor()])
train_data = torchvision.datasets.ImageFolder(config.data_folder,
                                              transform=transform)
# CIFAR-10 loader for Deep InfoMax training; drop_last keeps batch shapes
# fixed so the pairing trick below always has a full batch.
cifar_10_train_dt = CIFAR10(r'~/.torch', download=True, transform=ToTensor())
cifar_10_train_l = DataLoader(cifar_10_train_dt,
                              batch_size=batch_size,
                              shuffle=True,
                              drop_last=True,
                              pin_memory=torch.cuda.is_available())

encoder = Encoder().to(device)
loss_fn = DeepInfoMaxLoss().to(device)
# Separate Adam optimisers for the encoder and the loss module's own params.
optim = Adam(encoder.parameters(), lr=1e-4)
loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

# Optionally resume from the epoch-20 weights saved under ./Models/run1.
epoch_restart = 20
root = Path(f"./Models/run1")
if epoch_restart is not None and root is not None:
    enc_file = root / Path('encoder' + str(epoch_restart) + '.wgt')
    loss_file = root / Path('loss' + str(epoch_restart) + '.wgt')
    encoder.load_state_dict(torch.load(str(enc_file)))
    loss_fn.load_state_dict(torch.load(str(loss_file)))

for epoch in range(epoch_restart + 1, 30):
    batch = tqdm(cifar_10_train_l,
                 total=len(cifar_10_train_dt) // batch_size)
    train_loss = []
    for x, target in batch:
        x = x.to(device)

        optim.zero_grad()
        loss_optim.zero_grad()
        y, M = encoder(x)
        # rotate images to create pairs for comparison
        M_prime = torch.cat((M[1:], M[0].unsqueeze(0)),
                            dim=0)  # put the first one to the last
        loss = loss_fn(y, M, M_prime)
        train_loss.append(loss.item())
        # (backward/step and logging continue beyond this excerpt)
def main(args):
    """Train an image-captioning encoder/decoder on COCO.

    Optionally builds the vocabulary from the captions file, optionally
    initialises the decoder embedding from pickled GloVe vectors, optionally
    resumes model/optimizer state from ``args.ckpt_path``, and saves one
    checkpoint per epoch under ``args.models_dir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    if not os.path.exists(args.models_dir):
        os.makedirs(args.models_dir)

    if args.build_vocab:
        print(
            f"Building vocabulary from captions at {args.captions_json} and with count threshold={args.threshold}"
        )
        vocab_object = build_vocab(args.captions_json, args.threshold)
        with open(args.vocab_path, "wb") as vocab_f:
            pickle.dump(vocab_object, vocab_f)
        print(
            f"Saved the vocabulary object to {args.vocab_path}, total size={len(vocab_object)}"
        )
    else:
        with open(args.vocab_path, 'rb') as f:
            vocab_object = pickle.load(f)
        print(
            f"Loaded the vocabulary object from {args.vocab_path}, total size={len(vocab_object)}"
        )

    if args.glove_embed_path is not None:
        with open(args.glove_embed_path, 'rb') as f:
            glove_embeddings = pickle.load(f)
        print(
            f"Loaded the glove embeddings from {args.glove_embed_path}, total size={len(glove_embeddings)}"
        )
        # We are using 300d glove embeddings
        args.embed_size = 300
        weights_matrix = np.zeros((len(vocab_object), args.embed_size))
        for word, index in vocab_object.word2index.items():
            if word in glove_embeddings:
                weights_matrix[index] = glove_embeddings[word]
            else:
                # Unknown words get random vectors at GloVe-like scale.
                weights_matrix[index] = np.random.normal(
                    scale=0.6, size=(args.embed_size, ))
        weights_matrix = torch.from_numpy(weights_matrix).float().to(device)
    else:
        weights_matrix = None

    img_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    train_dataset = cocoDataset(args.image_root, args.captions_json,
                                vocab_object, img_transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn)

    encoder = Encoder(args.resnet_size, (3, 224, 224),
                      args.embed_size).to(device)
    decoder = Decoder(args.rnn_type, weights_matrix, len(vocab_object),
                      args.embed_size, args.hidden_size).to(device)

    # Only the encoder's projection layer and the decoder are optimised; the
    # embedding is trained only when not initialised from GloVe.
    encoder_learnable = list(encoder.linear.parameters())
    decoder_learnable = list(decoder.rnn.parameters()) + list(
        decoder.linear.parameters())
    if args.glove_embed_path is None:
        decoder_learnable = decoder_learnable + list(
            decoder.embedding.parameters())

    criterion = nn.CrossEntropyLoss()
    params = encoder_learnable + decoder_learnable
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    start_epoch = 0
    if args.ckpt_path is not None:
        model_ckpt = torch.load(args.ckpt_path)
        start_epoch = model_ckpt['epoch'] + 1
        prev_loss = model_ckpt['loss']
        encoder.load_state_dict(model_ckpt['encoder'])
        decoder.load_state_dict(model_ckpt['decoder'])
        optimizer.load_state_dict(model_ckpt['optimizer'])
        print(
            f"Loaded model and optimizer state from {args.ckpt_path}; start epoch at {start_epoch}; prev loss={prev_loss}"
        )

    total_examples = len(train_dataloader)
    for epoch in range(start_epoch, args.num_epochs):
        for i, (images, captions, lengths) in enumerate(train_dataloader):
            images = images.to(device)
            captions = captions.to(device)
            # Flatten the padded targets the same way the decoder flattens
            # its outputs, so CrossEntropyLoss sees matching shapes.
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True).data

            image_embeddings = encoder(images)
            outputs = decoder(image_embeddings, captions, lengths)
            loss = criterion(outputs, targets)

            decoder.zero_grad()
            encoder.zero_grad()
            loss.backward()
            optimizer.step()

            if i % args.log_interval == 0:
                loss_val = "{:.4f}".format(loss.item())
                perplexity_val = "{:5.4f}".format(np.exp(loss.item()))
                print(
                    f"epoch=[{epoch}/{args.num_epochs}], iteration=[{i}/{total_examples}], loss={loss_val}, perplexity={perplexity_val}"
                )

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                # Fix: store a plain float, not the live loss tensor — the
                # tensor is attached to the autograd graph and may live on
                # the GPU, bloating the checkpoint and complicating loads.
                'loss': loss.item()
            },
            os.path.join(args.models_dir,
                         'model-after-epoch-{}.ckpt'.format(epoch)))