tgt_dir = 'webcam'
test_dir = 'test'
cuda = torch.cuda.is_available()
test_loader = get_dataloader(data_dir, tgt_dir, batch_size=15, train=False)
# lam: weight of the confusion loss
lam = 0.01
# nu: weight of the soft loss
nu = 0.1

# load the pretrained and fine-tuned AlexNet-based model

encoder = Encoder()
cl_classifier = ClassClassifier(num_classes=31)
dm_classifier = DomainClassifier()
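# Together these form an ADDA-style adversarial domain-adaptation setup:
# a shared feature encoder, a class classifier over 31 categories
# (presumably Office-31, given the 'webcam' target domain), and a domain
# classifier for the confusion loss.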

encoder.load_state_dict(torch.load('./checkpoints/a2w/src_encoder_final.pth'))
cl_classifier.load_state_dict(
    torch.load('./checkpoints/a2w/src_classifier_final.pth'))

src_train_loader = get_dataloader(data_dir, src_dir, batch_size, train=True)
tgt_train_loader = get_dataloader(data_dir,
                                  tgt_train_dir,
                                  batch_size,
                                  train=True)
criterion = nn.CrossEntropyLoss()
# criterion_kl = nn.KLDivLoss()
if cuda:
    criterion = criterion.cuda()
    cl_classifier = cl_classifier.cuda()
    dm_classifier = dm_classifier.cuda()
    encoder = encoder.cuda()
Example #2
def main():

    parser = argparse.ArgumentParser(description='Test SimCLR model')
    parser.add_argument('--EPOCHS',
                        default=1,
                        type=int,
                        help='Number of epochs for training')
    parser.add_argument('--BATCH_SIZE',
                        default=64,
                        type=int,
                        help='Batch size')
    parser.add_argument(
        '--LOG_INT',
        type=int,
        default=100,
        help='how many batches to wait before logging training status')
    parser.add_argument('--SAVED_MODEL', default='./ckpt/model.pth')
    args = parser.parse_args()

    use_cuda = torch.cuda.is_available()

    device = torch.device("cuda" if use_cuda else "cpu")

    saved_model = Encoder()
    saved_model.load_state_dict(torch.load(args.SAVED_MODEL))
    # Freeze weights in the pretrained model
    for param in saved_model.parameters():
        param.requires_grad = False
    test_saved_model = Classifier(saved_model).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer_saved = optim.Adam(test_saved_model.fc.parameters())
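    # Linear-evaluation protocol: the encoder weights are frozen, so only the
    # new classifier head (`fc`) is optimised; test accuracy then reflects the
    # quality of the pretrained SimCLR representations.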

    standard_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.4914, 0.4822, 0.4465],
                             [0.2023, 0.1994, 0.2010])
    ])

    trainset = torchvision.datasets.CIFAR10(root='./data',
                                            train=True,
                                            download=False,
                                            transform=standard_transform)

    train_loader = torch.utils.data.DataLoader(trainset,
                                               batch_size=args.BATCH_SIZE,
                                               shuffle=True)

    testset = torchvision.datasets.CIFAR10(root='./data',
                                           train=False,
                                           download=False,
                                           transform=standard_transform)

    test_loader = torch.utils.data.DataLoader(testset,
                                              batch_size=args.BATCH_SIZE,
                                              shuffle=True)

    for epoch in range(args.EPOCHS):
        print("Performance on the saved model")
        train(args, test_saved_model, device, train_loader, optimizer_saved,
              epoch, criterion)
        test(args, test_saved_model, device, test_loader)
Example #3
class Dreamer(Agent):
    # The agent has its own replay buffer, update, act
    def __init__(self, args):
        """
    All paras are passed by args
    :param args: a dict that includes parameters
    """
        super().__init__()
        self.args = args
        # Initialise model parameters randomly
        self.transition_model = TransitionModel(
            args.belief_size, args.state_size, args.action_size,
            args.hidden_size, args.embedding_size,
            args.dense_act).to(device=args.device)

        self.observation_model = ObservationModel(
            args.symbolic,
            args.observation_size,
            args.belief_size,
            args.state_size,
            args.embedding_size,
            activation_function=(args.dense_act if args.symbolic else
                                 args.cnn_act)).to(device=args.device)

        self.reward_model = RewardModel(args.belief_size, args.state_size,
                                        args.hidden_size,
                                        args.dense_act).to(device=args.device)

        self.encoder = Encoder(args.symbolic, args.observation_size,
                               args.embedding_size,
                               args.cnn_act).to(device=args.device)

        self.actor_model = ActorModel(
            args.action_size,
            args.belief_size,
            args.state_size,
            args.hidden_size,
            activation_function=args.dense_act,
            fix_speed=args.fix_speed,
            throttle_base=args.throttle_base).to(device=args.device)

        self.value_model = ValueModel(args.belief_size, args.state_size,
                                      args.hidden_size,
                                      args.dense_act).to(device=args.device)

        self.value_model2 = ValueModel(args.belief_size, args.state_size,
                                       args.hidden_size,
                                       args.dense_act).to(device=args.device)

        self.pcont_model = PCONTModel(args.belief_size, args.state_size,
                                      args.hidden_size,
                                      args.dense_act).to(device=args.device)

        self.target_value_model = deepcopy(self.value_model)
        self.target_value_model2 = deepcopy(self.value_model2)

        for p in self.target_value_model.parameters():
            p.requires_grad = False
        for p in self.target_value_model2.parameters():
            p.requires_grad = False
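        # The frozen copies act as target networks: they provide bootstrapped
        # value targets in _compute_loss_critic and are only refreshed at the
        # end of update_parameters.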

        # collect the world-model parameters to be updated jointly
        self.world_param = list(self.transition_model.parameters())\
                          + list(self.observation_model.parameters())\
                          + list(self.reward_model.parameters())\
                          + list(self.encoder.parameters())
        if args.pcont:
            self.world_param += list(self.pcont_model.parameters())

        # setup optimizer
        self.world_optimizer = optim.Adam(self.world_param, lr=args.world_lr)
        self.actor_optimizer = optim.Adam(self.actor_model.parameters(),
                                          lr=args.actor_lr)
        self.value_optimizer = optim.Adam(list(self.value_model.parameters()) +
                                          list(self.value_model2.parameters()),
                                          lr=args.value_lr)

        # set up the free-nats tolerance for the KL term
        self.free_nats = torch.full(
            (1, ), args.free_nats, dtype=torch.float32,
            device=args.device)  # Allowed deviation in KL divergence

        # TODO: change it to the new replay buffer, in buffer.py
        self.D = ExperienceReplay(args.experience_size, args.symbolic,
                                  args.observation_size, args.action_size,
                                  args.bit_depth, args.device)

        if self.args.auto_temp:
            # setup for learning of alpha term (temp of the entropy term)
            self.log_temp = torch.zeros(1,
                                        requires_grad=True,
                                        device=args.device)
            self.target_entropy = -np.prod(
                args.action_size if not args.fix_speed
                else args.action_size - 1).item()  # heuristic value from SAC paper
            self.temp_optimizer = optim.Adam(
                [self.log_temp], lr=args.value_lr)  # use the same value_lr

        # TODO: print out the param used in Dreamer
        # var_counts = tuple(count_vars(module) for module in [self., self.ac.q1, self.ac.q2])
        # print('\nNumber of parameters: \t pi: %d, \t q1: %d, \t q2: %d\n' % var_counts)

    # def process_im(self, image, image_size=None, rgb=None):
    #   # Resize, put channel first, convert it to a tensor, centre it to [-0.5, 0.5] and add batch dimenstion.
    #
    #   def preprocess_observation_(observation, bit_depth):
    #     # Preprocesses an observation inplace (from float32 Tensor [0, 255] to [-0.5, 0.5])
    #     observation.div_(2 ** (8 - bit_depth)).floor_().div_(2 ** bit_depth).sub_(
    #       0.5)  # Quantise to given bit depth and centre
    #     observation.add_(torch.rand_like(observation).div_(
    #       2 ** bit_depth))  # Dequantise (to approx. match likelihood of PDF of continuous images vs. PMF of discrete images)
    #
    #   image = image[40:, :, :]  # clip the above 40 rows
    #   image = torch.tensor(cv2.resize(image, (40, 40), interpolation=cv2.INTER_LINEAR).transpose(2, 0, 1),
    #                         dtype=torch.float32)  # Resize and put channel first
    #
    #   preprocess_observation_(image, self.args.bit_depth)
    #   return image.unsqueeze(dim=0)
    def process_im(self, images, image_size=None, rgb=None):
        images = cv2.resize(images, (40, 40))
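        # grayscale conversion using the ITU-R BT.601 luma weights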
        images = np.dot(images, [0.299, 0.587, 0.114])
        obs = torch.tensor(images,
                           dtype=torch.float32).div_(255.).sub_(0.5).unsqueeze(
                               dim=0)  # shape [1, 40, 40], range:[-0.5,0.5]
        return obs.unsqueeze(dim=0)  # add batch dimension

    def append_buffer(self, new_traj):
        # append a newly collected trajectory; data augmentation is not applied
        # shape of new_traj: [(o, a, r, d) * steps]
        for state in new_traj:
            observation, action, reward, done = state
            self.D.append(observation, action.cpu(), reward, done)

    def _compute_loss_world(self, state, data):
        # unpack the data
        beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = state
        observations, rewards, nonterminals = data

        # observation_loss = F.mse_loss(
        #   bottle(self.observation_model, (beliefs, posterior_states)),
        #   observations[1:],
        #   reduction='none').sum(dim=2 if self.args.symbolic else (2, 3, 4)).mean(dim=(0, 1))
        #
        # reward_loss = F.mse_loss(
        #   bottle(self.reward_model, (beliefs, posterior_states)),
        #   rewards[1:],
        #   reduction='none').mean(dim=(0,1))

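        # `bottle` (a helper carried over from PlaNet-style codebases) flattens
        # the leading [time, batch] dims, applies the wrapped model, and then
        # unflattens, so step-wise models can score a whole rollout at once.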
        observation_loss = F.mse_loss(
            bottle(self.observation_model, (beliefs, posterior_states)),
            observations,
            reduction='none').sum(
                dim=2 if self.args.symbolic else (2, 3, 4)).mean(dim=(0, 1))

        reward_loss = F.mse_loss(bottle(self.reward_model,
                                        (beliefs, posterior_states)),
                                 rewards,
                                 reduction='none').mean(dim=(0, 1))  # TODO: 5

        # transition loss
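        # Independent(..., 1) folds the state dimension into the event shape,
        # so kl_divergence returns one KL value per [time, batch] element,
        # clamped from below by free_nats before averaging.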
        kl_loss = torch.max(
            kl_divergence(
                Independent(Normal(posterior_means, posterior_std_devs), 1),
                Independent(Normal(prior_means, prior_std_devs), 1)),
            self.free_nats).mean(dim=(0, 1))

        # print("check the reward", bottle(pcont_model, (beliefs, posterior_states)).shape, nonterminals[:-1].shape)
        if self.args.pcont:
            pcont_loss = F.binary_cross_entropy(
                bottle(self.pcont_model, (beliefs, posterior_states)),
                nonterminals)
            # pcont_pred = torch.distributions.Bernoulli(logits=bottle(self.pcont_model, (beliefs, posterior_states)))
            # pcont_loss = -pcont_pred.log_prob(nonterminals[1:]).mean(dim=(0, 1))

        return observation_loss, self.args.reward_scale * reward_loss, kl_loss, (
            self.args.pcont_scale * pcont_loss if self.args.pcont else 0)

    def _compute_loss_actor(self,
                            imag_beliefs,
                            imag_states,
                            imag_ac_logps=None):
        # reward and value prediction of imagined trajectories
        imag_rewards = bottle(self.reward_model, (imag_beliefs, imag_states))
        imag_values = bottle(self.value_model, (imag_beliefs, imag_states))
        imag_values2 = bottle(self.value_model2, (imag_beliefs, imag_states))
        imag_values = torch.min(imag_values, imag_values2)
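        # Element-wise minimum over the two value heads: the clipped double-Q
        # trick (TD3/SAC) that counteracts value overestimation.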

        with torch.no_grad():
            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)
        pcont = pcont.detach()

        if imag_ac_logps is not None:
            imag_values[
                1:] -= self.args.temp * imag_ac_logps  # add entropy here

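        # cal_returns computes Dreamer-style lambda-returns: an exponentially
        # weighted mixture of n-step bootstrapped value targets, with the
        # mixing weight lambda_ taken from args.disclam.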
        returns = cal_returns(imag_rewards[:-1],
                              imag_values[:-1],
                              imag_values[-1],
                              pcont[:-1],
                              lambda_=self.args.disclam)

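        # Cumulative product of continuation probabilities gives each imagined
        # step its discount weight; the first step always has weight 1.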
        discount = torch.cumprod(
            torch.cat([torch.ones_like(pcont[:1]), pcont[:-2]], 0), 0)
        discount = discount.detach()

        assert list(discount.size()) == list(returns.size())
        actor_loss = -torch.mean(discount * returns)
        return actor_loss

    def _compute_loss_critic(self,
                             imag_beliefs,
                             imag_states,
                             imag_ac_logps=None):

        with torch.no_grad():
            # calculate the target with the target nn
            target_imag_values = bottle(self.target_value_model,
                                        (imag_beliefs, imag_states))
            target_imag_values2 = bottle(self.target_value_model2,
                                         (imag_beliefs, imag_states))
            target_imag_values = torch.min(target_imag_values,
                                           target_imag_values2)
            imag_rewards = bottle(self.reward_model,
                                  (imag_beliefs, imag_states))

            if self.args.pcont:
                pcont = bottle(self.pcont_model, (imag_beliefs, imag_states))
            else:
                pcont = self.args.discount * torch.ones_like(imag_rewards)

        # print("check pcont", pcont)
            if imag_ac_logps is not None:
                target_imag_values[1:] -= self.args.temp * imag_ac_logps

        returns = cal_returns(imag_rewards[:-1],
                              target_imag_values[:-1],
                              target_imag_values[-1],
                              pcont[:-1],
                              lambda_=self.args.disclam)
        target_return = returns.detach()

        value_pred = bottle(self.value_model, (imag_beliefs, imag_states))[:-1]
        value_pred2 = bottle(self.value_model2,
                             (imag_beliefs, imag_states))[:-1]

        value_loss = F.mse_loss(value_pred, target_return,
                                reduction="none").mean(dim=(0, 1))
        value_loss2 = F.mse_loss(value_pred2, target_return,
                                 reduction="none").mean(dim=(0, 1))
        value_loss += value_loss2

        return value_loss

    def _latent_imagination(self,
                            beliefs,
                            posterior_states,
                            with_logprob=False):
        # Rollout to generate imagined trajectories

        chunk_size, batch_size, _ = list(
            posterior_states.size())  # flatten the tensor
        flatten_size = chunk_size * batch_size

        posterior_states = posterior_states.detach().reshape(flatten_size, -1)
        beliefs = beliefs.detach().reshape(flatten_size, -1)

        imag_beliefs, imag_states, imag_ac_logps = [beliefs
                                                    ], [posterior_states], []

        for i in range(self.args.planning_horizon):
            imag_action, imag_ac_logp = self.actor_model(
                imag_beliefs[-1].detach(),
                imag_states[-1].detach(),
                deterministic=False,
                with_logprob=with_logprob,
            )
            imag_action = imag_action.unsqueeze(dim=0)  # add time dim

            # print(imag_states[-1].shape, imag_action.shape, imag_beliefs[-1].shape)
            imag_belief, imag_state, _, _ = self.transition_model(
                imag_states[-1], imag_action, imag_beliefs[-1])
            imag_beliefs.append(imag_belief.squeeze(dim=0))
            imag_states.append(imag_state.squeeze(dim=0))
            if with_logprob:
                imag_ac_logps.append(imag_ac_logp.squeeze(dim=0))

        imag_beliefs = torch.stack(imag_beliefs, dim=0).to(
            self.args.device
        )  # shape [horizon+1, (chunk-1)*batch, belief_size]
        imag_states = torch.stack(imag_states, dim=0).to(self.args.device)
        if with_logprob:
            imag_ac_logps = torch.stack(imag_ac_logps, dim=0).to(
                self.args.device)  # shape [horizon, (chunk-1)*batch]

        return imag_beliefs, imag_states, imag_ac_logps if with_logprob else None

    def update_parameters(self, gradient_steps):
        loss_info = []  # used to record loss
        for s in tqdm(range(gradient_steps)):
            # get state and belief of samples
            observations, actions, rewards, nonterminals = self.D.sample(
                self.args.batch_size, self.args.chunk_size)
            # print("check sampled rewrads", rewards)
            init_belief = torch.zeros(self.args.batch_size,
                                      self.args.belief_size,
                                      device=self.args.device)
            init_state = torch.zeros(self.args.batch_size,
                                     self.args.state_size,
                                     device=self.args.device)

            # Update belief/state using posterior from previous belief/state, previous action and current observation (over entire sequence at once)
            # beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(
            #   init_state,
            #   actions[:-1],
            #   init_belief,
            #   bottle(self.encoder, (observations[1:], )),
            #   nonterminals[:-1])

            beliefs, prior_states, prior_means, prior_std_devs, posterior_states, posterior_means, posterior_std_devs = self.transition_model(
                init_state, actions, init_belief,
                bottle(self.encoder, (observations, )),
                nonterminals)  # TODO: 4

            # update paras of world model
            world_model_loss = self._compute_loss_world(
                state=(beliefs, prior_states, prior_means, prior_std_devs,
                       posterior_states, posterior_means, posterior_std_devs),
                data=(observations, rewards, nonterminals))
            observation_loss, reward_loss, kl_loss, pcont_loss = world_model_loss
            self.world_optimizer.zero_grad()
            (observation_loss + reward_loss + kl_loss + pcont_loss).backward()
            nn.utils.clip_grad_norm_(self.world_param,
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            self.world_optimizer.step()

            # freeze params to save memory
            for p in self.world_param:
                p.requires_grad = False
            for p in self.value_model.parameters():
                p.requires_grad = False
            for p in self.value_model2.parameters():
                p.requires_grad = False

            # latent imagination
            imag_beliefs, imag_states, imag_ac_logps = self._latent_imagination(
                beliefs, posterior_states, with_logprob=self.args.with_logprob)

            # update temp
            if self.args.auto_temp:
                temp_loss = -(
                    self.log_temp *
                    (imag_ac_logps[0] + self.target_entropy).detach()).mean()
                self.temp_optimizer.zero_grad()
                temp_loss.backward()
                self.temp_optimizer.step()
                self.args.temp = self.log_temp.exp()

            # update actor
            actor_loss = self._compute_loss_actor(imag_beliefs,
                                                  imag_states,
                                                  imag_ac_logps=imag_ac_logps)

            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_model.parameters(),
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            self.actor_optimizer.step()

            for p in self.world_param:
                p.requires_grad = True
            for p in self.value_model.parameters():
                p.requires_grad = True
            for p in self.value_model2.parameters():
                p.requires_grad = True

            # update critic
            imag_beliefs = imag_beliefs.detach()
            imag_states = imag_states.detach()

            critic_loss = self._compute_loss_critic(
                imag_beliefs, imag_states, imag_ac_logps=imag_ac_logps)

            self.value_optimizer.zero_grad()
            critic_loss.backward()
            nn.utils.clip_grad_norm_(self.value_model.parameters(),
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            nn.utils.clip_grad_norm_(self.value_model2.parameters(),
                                     self.args.grad_clip_norm,
                                     norm_type=2)
            self.value_optimizer.step()

            loss_info.append([
                observation_loss.item(),
                reward_loss.item(),
                kl_loss.item(),
                pcont_loss.item() if self.args.pcont else 0,
                actor_loss.item(),
                critic_loss.item()
            ])

        # finally, refresh the target value functions once per call (i.e. every `gradient_steps` gradient updates)
        with torch.no_grad():
            self.target_value_model.load_state_dict(
                self.value_model.state_dict())
            self.target_value_model2.load_state_dict(
                self.value_model2.state_dict())
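        # (This is a hard update; a Polyak-style soft update would instead
        # blend the online and target weights with a small factor tau.)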

        return loss_info

    def infer_state(self, observation, action, belief=None, state=None):
        """ Infer belief over current state q(s_t|o≤t,a<t) from the history,
        return updated belief and posterior_state at time t
        returned shape: belief/state [belief/state_dim] (remove the time_dim)
    """
        # observation is already on the device; action has shape [act_dim]
        # (the extra time dim is added inside this fn)
        belief, _, _, _, posterior_state, _, _ = self.transition_model(
            state, action.unsqueeze(dim=0), belief,
            self.encoder(observation).unsqueeze(
                dim=0))  # Action and observation need extra time dimension

        belief, posterior_state = belief.squeeze(
            dim=0), posterior_state.squeeze(
                dim=0)  # Remove time dimension from belief/state

        return belief, posterior_state

    def select_action(self, state, deterministic=False):
        # get an action from the (belief, posterior_state) pair produced by
        # infer_state; returns a tensor of shape [batch, act_size]
        belief, posterior_state = state
        action, _ = self.actor_model(belief,
                                     posterior_state,
                                     deterministic=deterministic,
                                     with_logprob=False)
        if not deterministic and not self.args.with_logprob:
            # add exploration noise around the mean action
            action = Normal(action, self.args.expl_amount).rsample()

            # clip the angle
            action[:, 0].clamp_(min=self.args.angle_min,
                                max=self.args.angle_max)
            # clip the throttle
            if self.args.fix_speed:
                action[:, 1] = self.args.throttle_base
            else:
                action[:, 1].clamp_(min=self.args.throttle_min,
                                    max=self.args.throttle_max)
        print("action", action)
        # return action.cup().numpy()
        return action  # this is a Tonsor.cuda

    def import_parameters(self, params):
        # import/export only the parameters needed for local rollouts
        self.encoder.load_state_dict(params["encoder"])
        self.actor_model.load_state_dict(params["policy"])
        self.transition_model.load_state_dict(params["transition"])

    def export_parameters(self):
        """ return the model paras used for local rollout """
        params = {
            "encoder": self.encoder.cpu().state_dict(),
            "policy": self.actor_model.cpu().state_dict(),
            "transition": self.transition_model.cpu().state_dict()
        }

        self.encoder.to(self.args.device)
        self.actor_model.to(self.args.device)
        self.transition_model.to(self.args.device)

        return params
Example #4
                        type=float,
                        nargs='+',
                        help='x, y values from the disruption to decoding')
    return parser.parse_args()


args = manage()

cuda = args.cuda and torch.cuda.is_available()

encoder = Encoder()
decoder = Decoder()

if cuda:
    encoder.load_state_dict(torch.load(args.encoder_path))
    encoder.cuda()
else:
    encoder.load_state_dict(torch.load(args.encoder_path, map_location='cpu'))

if cuda:
    decoder.load_state_dict(torch.load(args.decoder_path))
    decoder.cuda()
else:
    decoder.load_state_dict(torch.load(args.decoder_path, map_location='cpu'))

if args.image_path:
    image = cv2.imread(args.image_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    code = encoder(transforms.ToTensor()(transforms.ToPILImage()(image)).view(
        -1, 784))
Example #5
    metrics['episodes'].append(s)


# Initialise model parameters randomly
transition_model = TransitionModel(args.belief_size, args.state_size, env.action_size, args.hidden_size, args.embedding_size, args.activation_function).to(device=args.device)
observation_model = ObservationModel(args.symbolic_env, env.observation_size, args.belief_size, args.state_size, args.embedding_size, args.activation_function).to(device=args.device)
reward_model = RewardModel(args.belief_size, args.state_size, args.hidden_size, args.activation_function).to(device=args.device)
encoder = Encoder(args.symbolic_env, env.observation_size, args.embedding_size, args.activation_function).to(device=args.device)
param_list = list(transition_model.parameters()) + list(observation_model.parameters()) + list(reward_model.parameters()) + list(encoder.parameters())
optimiser = optim.Adam(param_list, lr=0 if args.learning_rate_schedule != 0 else args.learning_rate, eps=args.adam_epsilon)
if args.load_checkpoint > 0:
  model_dicts = torch.load(os.path.join(results_dir, 'models_%d.pth' % args.load_checkpoint))
  transition_model.load_state_dict(model_dicts['transition_model'])
  observation_model.load_state_dict(model_dicts['observation_model'])
  reward_model.load_state_dict(model_dicts['reward_model'])
  encoder.load_state_dict(model_dicts['encoder'])
  optimiser.load_state_dict(model_dicts['optimiser'])
planner = MPCPlanner(env.action_size, args.planning_horizon, args.optimisation_iters, args.candidates, args.top_candidates, transition_model, reward_model)
global_prior = Normal(torch.zeros(args.batch_size, args.state_size, device=args.device), torch.ones(args.batch_size, args.state_size, device=args.device))  # Global prior N(0, I)
free_nats = torch.full((1, ), args.free_nats, device=args.device)  # Allowed deviation in KL divergence


def update_belief_and_act(args, env, planner, transition_model, encoder, belief, posterior_state, action, observation, test):
  # Infer belief over current state q(s_t|o≤t,a<t) from the history
  belief, _, _, _, posterior_state, _, _ = transition_model(posterior_state, action.unsqueeze(dim=0), belief, encoder(observation).unsqueeze(dim=0))  # Action and observation need extra time dimension
  belief, posterior_state = belief.squeeze(dim=0), posterior_state.squeeze(dim=0)  # Remove time dimension from belief/state
  action = planner(belief, posterior_state)  # Get action from planner(q(s_t|o≤t,a<t), p)
  if not test:
    action = action + args.action_noise * torch.randn_like(action)  # Add exploration noise ε ~ p(ε) to the action
  next_observation, reward, done = env.step(action.cpu() if isinstance(env, EnvBatcher) else action[0].cpu())  # Perform environment step (action repeats handled internally)
  return belief, posterior_state, action, next_observation, reward, done
Example #6
# Pretrained model
alexnet = torchvision.models.alexnet(pretrained=True)
pretrained_dict = alexnet.state_dict()
# Train source data
# Model parameters
src_encoder = Encoder()
src_classifier = ClassClassifier(num_classes=31)
src_encoder_dict = src_encoder.state_dict()
# Load pretrained model 
# filter out unnecessary keys
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in src_encoder_dict}
# overwrite entries in the existing state dict
src_encoder_dict.update(pretrained_dict) 
# load the new state dict
src_encoder.load_state_dict(src_encoder_dict)
optimizer = optim.SGD(
    list(src_encoder.parameters()) + list(src_classifier.parameters()),
    lr=lr,
    momentum=momentum,
    weight_decay=weight_decay)

criterion = nn.CrossEntropyLoss()

if cuda: 
    src_encoder = src_encoder.cuda()
    src_classifier = src_classifier.cuda() 
    criterion = criterion.cuda() 

src_encoder.train()
src_classifier.train()
Example #7
def main():

    print('Training parameters Initialized')
    training_parameters = TrainingParameters( start_epoch = 0,
                                            epochs = 120,  # number of epochs to train for
                                            epochs_since_improvement = 0,  # Epochs since improvement in BLEU score
                                            batch_size = 32,
                                            workers = 1,  # for data-loading; right now, only 1 works with h5py
                                            fine_tune_encoder = True,  # fine-tune encoder
                                            encoder_lr = 1e-4,  # learning rate for encoder, if fine-tuning is used
                                            decoder_lr = 4e-4,  # learning rate for decoder
                                            grad_clip = 5.0,  # clip gradients at an absolute value of
                                            alpha_c = 1.0,  # regularization parameter for 'doubly stochastic attention'
                                            best_bleu4 = 0.0,  # BLEU-4 score right now
                                            print_freq = 100,  # print training/validation stats every __ batches
                                            checkpoint =  './Result/BEST_checkpoint_flickr8k_5_captions_per_image_5_minimum_word_frequency.pth.tar' # path to checkpoint, None if none
                                            # checkpoint = None
                                          )

    print('Loading Word-Map')
    word_map_file = os.path.join(data_folder,'WORDMAP_' + data_name + '.json')
    with open(word_map_file, 'r') as j:
        word_map = json.load(j)

    print('Creating Model')

    if training_parameters.checkpoint is None:
        encoder = Encoder()
        encoder.fine_tune(training_parameters.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, encoder.parameters()),
                                                lr=training_parameters.encoder_lr) if training_parameters.fine_tune_encoder else None
        
        decoder = Decoder(attention_dimension = attention_dimension,
                            embedding_dimension = embedding_dimension,
                            hidden_dimension = hidden_dimension,
                            vocab_size = len(word_map),
                            device = device,
                            dropout = dropout)                            
        decoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, decoder.parameters()),
                                                lr=training_parameters.decoder_lr)

    else:
        checkpoint = torch.load(training_parameters.checkpoint)
        training_parameters.start_epoch = checkpoint['epoch'] + 1
        training_parameters.epochs_since_improvement = checkpoint['epochs_since_improvement']
        training_parameters.best_bleu4 = checkpoint['bleu4']

        encoder = Encoder()
        encoder.load_state_dict(checkpoint['encoder_state_dict'])
        encoder_optimizer = checkpoint['encoder_optimizer']

        decoder = Decoder(attention_dimension = attention_dimension,
                            embedding_dimension = embedding_dimension,
                            hidden_dimension = hidden_dimension,
                            vocab_size = len(word_map),
                            device = device,
                            dropout = dropout)
        decoder.load_state_dict(checkpoint['decoder_state_dict'])
        decoder_optimizer = checkpoint['decoder_optimizer']

        if training_parameters.fine_tune_encoder is True and encoder_optimizer is None:
            encoder.fine_tune(training_parameters.fine_tune_encoder)
        encoder_optimizer = torch.optim.Adam(params=filter(lambda p : p.requires_grad, encoder.parameters()),
                                                lr=training_parameters.encoder_lr)

    encoder.to(device)
    decoder.to(device)

    criterion = nn.CrossEntropyLoss().to(device)
        
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    print('Creating Data Loaders')
    train_dataloader = torch.utils.data.DataLoader(
                                    CaptionDataset(data_folder, data_name, 'TRAIN', transform=transforms.Compose([normalize])),
                                    batch_size=training_parameters.batch_size, shuffle=True)
    
    validation_dataloader = torch.utils.data.DataLoader(
                                    CaptionDataset(data_folder, data_name, 'VALID', transform=transforms.Compose([normalize])),
                                    batch_size=training_parameters.batch_size, shuffle=True, pin_memory=True)

    for epoch in range(training_parameters.start_epoch, training_parameters.epochs):

        if training_parameters.epochs_since_improvement == 20:
            break
        if training_parameters.epochs_since_improvement > 0  and training_parameters.epochs_since_improvement % 8 == 0:
            adjust_learning_rate(decoder_optimizer, 0.8)
            if training_parameters.fine_tune_encoder:
                adjust_learning_rate(encoder_optimizer, 0.8)

        train(train_loader = train_dataloader,
              encoder = encoder,
              decoder = decoder,
              criterion = criterion,
              encoder_optimizer = encoder_optimizer,
              decoder_optimizer = decoder_optimizer,
              epoch = epoch,
              device = device,
              training_parameters = training_parameters)

        recent_bleu4_score = validate(validation_loader = validation_dataloader,
                                    encoder = encoder,
                                    decoder = decoder,
                                    criterion = criterion,
                                    word_map = word_map,
                                    device = device,
                                    training_parameters = training_parameters)

        is_best_score = recent_bleu4_score > training_parameters.best_bleu4
        training_parameters.best_bleu4 = max(recent_bleu4_score, training_parameters.best_bleu4)
        if not is_best_score:
            training_parameters.epochs_since_improvement += 1
            print('\nEpochs since last improvement : %d\n' % (training_parameters.epochs_since_improvement))
        else:
            training_parameters.epochs_since_improvement = 0
        
        save_checkpoint(data_name, epoch, training_parameters.epochs_since_improvement, encoder, decoder,
                        encoder_optimizer, decoder_optimizer, recent_bleu4_score, is_best_score)
Example #8
data_name = 'flickr8k_5_cap_per_img_5_min_word_freq'
word_map_file = 'datasets/caption_data/WORDMAP_' + data_name + '.json'
checkpoint = 'logs/tmp/BEST_MODEL.pth.tar'
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

with open(word_map_file, 'r') as f:
    word_map = json.load(f)
rev_word_map = {v: k for k, v in word_map.items()}
vocab_size = len(word_map)

# Load the model
encoder = Encoder()
decoder = DecoderWithAttention(512, 512, 512, vocab_size)
checkpoint = torch.load(checkpoint)
encoder.load_state_dict(checkpoint['encoder'])
decoder.load_state_dict(checkpoint['decoder'])
encoder.to(device)
decoder.to(device)
encoder.eval()
decoder.eval()


preprocess = T.Compose([
    T.Resize(size=(256, 256)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])


def get_image_caption(ori_img):
Example #9
        "WARNING: You have a CUDA device, so you should probably run with --cuda"
    )

###### Definition of variables ######
# Networks
encoder = Encoder(input_nc=input_nc)
decoder_A2B = Decoder(output_nc=output_nc)
decoder_B2A = Decoder(output_nc=output_nc)

if activate_cuda:
    encoder.cuda()
    decoder_A2B.cuda()
    decoder_B2A.cuda()

# Load state dicts
encoder.load_state_dict(torch.load(encoder_path))
decoder_A2B.load_state_dict(torch.load(decoder_A2B_path))
decoder_B2A.load_state_dict(torch.load(decoder_B2A_path))

# Set model's test mode
encoder.eval()
decoder_A2B.eval()
decoder_B2A.eval()

# Inputs & targets memory allocation
Tensor = torch.cuda.FloatTensor if activate_cuda else torch.Tensor
input_A = Tensor(batch_size, input_nc, image_size, image_size)
input_B = Tensor(batch_size, output_nc, image_size, image_size)

# Dataset loader
transforms_ = [
Example #10
def make_representations(arguments, device):
    """
    Creates representations for all data.
    :param arguments: Dictionary containing arguments.
    :param device: PyTorch device object.
    """

    # Loads training and testing data.
    train_data = Dataset(arguments, "train")
    test_data = Dataset(arguments, "test")

    # Creates the data loaders for the training and testing data.
    training_data_loader = DataLoader(train_data,
                                      batch_size=arguments["batch_size"],
                                      shuffle=False,
                                      num_workers=arguments["data_workers"],
                                      pin_memory=False,
                                      drop_last=False)
    testing_data_loader = DataLoader(test_data,
                                     batch_size=arguments["batch_size"],
                                     shuffle=False,
                                     num_workers=arguments["data_workers"],
                                     pin_memory=False,
                                     drop_last=False)

    log(arguments, "Loaded Datasets")

    # Initialises the encoder.
    encoder = Encoder(0, arguments["image_size"],
                      arguments["pretrained"] == "imagenet")

    # Loads weights from pretrained Contrastive Predictive Coding model.
    if arguments["pretrained"].lower() == "cpc":
        encoder_path = os.path.join(
            arguments["model_dir"],
            f"{arguments['experiment']}_encoder_best.pt")
        encoder.load_state_dict(torch.load(encoder_path, map_location=device),
                                strict=False)

    # Sets the model to evaluation mode.
    encoder.eval()

    # Moves the model to the selected device.
    encoder.to(device)

    # If 16 bit precision is being used change the model and optimiser precision.
    if arguments["precision"] == 16:
        encoder = amp.initialize(encoder, opt_level="O2", verbosity=False)

    # Checks if precision level is supported and if not defaults to 32.
    elif arguments["precision"] != 32:
        log(
            arguments,
            "Only 16 and 32 bit precision supported. Defaulting to 32 bit precision."
        )

    log(arguments, "Models Initialised")

    # Creates a folder if one does not exist.
    os.makedirs(os.path.dirname(arguments["representation_dir"]),
                exist_ok=True)

    # Creates the HDF5 files used to store the training and testing data representations.
    train_representations = HDF5Handler(
        os.path.join(arguments["representation_dir"],
                     f"{arguments['experiment']}_train.h5"), 'x',
        (encoder.encoder_size, ))
    test_representations = HDF5Handler(
        os.path.join(arguments["representation_dir"],
                     f"{arguments['experiment']}_test.h5"), 'x',
        (encoder.encoder_size, ))

    log(arguments, "HDF5 Representation Files Created.")

    # Starts a timer.
    start_time = time.time()

    # Performs a representation generation with no gradients.
    with torch.no_grad():

        # Loops through the training data.
        num_batches = 0
        for images, _ in training_data_loader:

            # Loads the image batch into memory.
            images = images.to(device)

            # Gets the representations of the image batch from the encoder.
            representations = encoder.forward_features(images)

            # Moves the representations to the CPU.
            representations = representations.cpu().data.numpy()

            # Adds the batch representations to the HDF5 file.
            train_representations.append(representations)

            # Prints information about representation extraction process.
            num_batches += 1
            if num_batches % arguments["log_intervals"] == 0:
                print(
                    f"Training Batches: {num_batches}/{len(train_data) // arguments['batch_size']}"
                )

        # Loops through the testing data.
        num_batches = 0
        for images, _ in testing_data_loader:

            # Loads the image batch into memory.
            images = images.to(device)

            # Gets the representations of the image batch from the encoder.
            representations = encoder.forward_features(images)

            # Moves the representations to the CPU.
            representations = representations.cpu().data.numpy()

            # Adds the batch representations to the HDF5 file.
            test_representations.append(representations)

            # Prints information about representation extraction process.
            num_batches += 1
            if num_batches % arguments["log_intervals"] == 0:
                print(
                    f"Testing Batches: {num_batches}/{len(test_data) // arguments['batch_size']}"
                )

    print(
        f"Representations from {arguments['experiment']} encoder created in {int(time.time() - start_time)}s"
    )
Example #11
opt_encoder = torch.optim.Adam(encoder.parameters(), lr=knobs["lr_encoder"])
opt_decoder = torch.optim.Adam(decoder.parameters(), lr=knobs["lr_decoder"])

collector_reconstruction_loss = Collector()
collector_wasserstein_penalty = Collector()
collector_fooling_term = Collector()
collector_codes_min = Collector()
collector_codes_max = Collector()
if knobs["resume"]:
    writer = SummaryWriter(log_dir_last_modified)
    checkpoint_dir = checkpoints_dir_last_modified
    checkpoint = torch.load(checkpoint_dir)
    starting_epoch = checkpoint["epoch"]
    iteration = checkpoint["iteration"]
    encoder.load_state_dict(checkpoint["encoder_state_dict"])
    decoder.load_state_dict(checkpoint["decoder_state_dict"])
    opt_encoder.load_state_dict(checkpoint["opt_encoder_state_dict"])
    opt_decoder.load_state_dict(checkpoint["opt_decoder_state_dict"])
else:
    writer = SummaryWriter(log_dir_local_time)
    checkpoint_dir = checkpoints_dir_local_time
    starting_epoch = 1
    iteration = 0

encoder.train()
decoder.train()
for epoch in range(starting_epoch, knobs["num_epochs"] + 1):
    for batch in loader:
        iteration += 1
Example #12
def train_CIFAR10(opt):

    import torchvision.datasets as datasets
    import torchvision.transforms as transforms
    from torchvision.utils import make_grid
    from matplotlib import pyplot as plt
    params = get_config(opt.config)

    save_path = os.path.join(
        params['save_path'],
        datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S'))
    os.makedirs(save_path, exist_ok=True)
    shutil.copy('models.py', os.path.join(save_path, 'models.py'))
    shutil.copy('train.py', os.path.join(save_path, 'train.py'))
    shutil.copy(opt.config,
                os.path.join(save_path, os.path.basename(opt.config)))

    cuda = torch.cuda.is_available()
    gpu_ids = [i for i in range(torch.cuda.device_count())]

    TensorType = torch.cuda.FloatTensor if cuda else torch.Tensor

    data_path = os.path.join(params['data_root'], 'cifar10')

    os.makedirs(data_path, exist_ok=True)

    train_dataset = datasets.CIFAR10(root=data_path,
                                     train=True,
                                     download=True,
                                     transform=transforms.Compose([
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.5, 0.5, 0.5),
                                                              (0.5, 0.5, 0.5))
                                     ]))

    val_dataset = datasets.CIFAR10(root=data_path,
                                   train=False,
                                   download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.5, 0.5, 0.5),
                                                            (0.5, 0.5, 0.5))
                                   ]))

    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'] * len(gpu_ids),
                              shuffle=True,
                              num_workers=params['num_workers'],
                              pin_memory=cuda)
    val_loader = DataLoader(val_dataset,
                            batch_size=1,
                            num_workers=params['num_workers'],
                            pin_memory=cuda)

    data_variance = np.var(train_dataset.data / 255.0)
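    # Normalising the reconstruction error by the data variance follows the
    # reference VQ-VAE implementation, keeping the loss scale comparable
    # across datasets.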

    encoder = Encoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])
    decoder = Decoder(params['dim'], params['residual_channels'],
                      params['n_layers'], params['d'])

    vq = VectorQuantizer(params['k'], params['d'], params['beta'],
                         params['decay'], TensorType)

    if params['checkpoint'] is not None:
        checkpoint = torch.load(params['checkpoint'])

        params['start_epoch'] = checkpoint['epoch']
        encoder.load_state_dict(checkpoint['encoder'])
        decoder.load_state_dict(checkpoint['decoder'])
        vq.load_state_dict(checkpoint['vq'])

    model = VQVAE(encoder, decoder, vq)

    if cuda:
        model = nn.DataParallel(model.cuda(), device_ids=gpu_ids)

    parameters = list(model.parameters())
    optimizer = torch.optim.Adam([p for p in parameters if p.requires_grad],
                                 lr=params['lr'])

    for epoch in range(params['start_epoch'], params['num_epochs']):
        train_bar = tqdm(train_loader)
        for data, _ in train_bar:
            if cuda:
                data = data.cuda()
            optimizer.zero_grad()

            vq_loss, data_recon, _ = model(data)
            recon_error = torch.mean((data_recon - data)**2) / data_variance
            loss = recon_error + vq_loss.mean()
            loss.backward()
            optimizer.step()

            train_bar.set_description('Epoch {}: loss {:.4f}'.format(
                epoch + 1,
                loss.mean().item()))

        model.eval()
        data_val, _ = next(iter(val_loader))

        if cuda:
            data_val = data_val.cuda()
        with torch.no_grad():
            _, data_recon_val, _ = model(data_val)

        plt.imsave(os.path.join(save_path, 'latest_val_recon.png'),
                   (make_grid(data_recon_val.cpu().data) +
                    0.5).numpy().transpose(1, 2, 0))
        plt.imsave(os.path.join(save_path, 'latest_val_orig.png'),
                   (make_grid(data_val.cpu().data) + 0.5).numpy().transpose(
                       1, 2, 0))

        model.train()

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'vq': vq.state_dict(),
            }, os.path.join(save_path, '{}_checkpoint.pth'.format(epoch)))
Example #13
    'data/train.en.txt', 'data/train.vi.txt', args.max_length)
English = utils.LanguageModel("English")
[English.add_line(i) for i in English_sentences]
Vietnamese = utils.LanguageModel("Vietnamese")
[Vietnamese.add_line(i) for i in Vietnamese_sentences]
English_tokens = [English.tokens_from_line(i) for i in English_sentences]
Vietnamese_tokens = [
    Vietnamese.tokens_from_line(i) for i in Vietnamese_sentences
]
encoder = Encoder(args, input_size=English.num_words)
decoder = Decoder(args, output_size=Vietnamese.num_words)

args.output_size = Vietnamese.num_words
if 'train' in args.mode:
    if args.checkpoint:
        encoder.load_state_dict(
            torch.load(os.path.join('model', 'encoder.pkl')))
        decoder.load_state_dict(
            torch.load(os.path.join('model', 'decoder.pkl')))
        encoder.eval()
        decoder.eval()
    utils.train_handler(args, encoder, decoder, English_tokens,
                        Vietnamese_tokens, English, Vietnamese)
    torch.save(encoder.state_dict(), os.path.join('model', 'encoder.pkl'))
    torch.save(decoder.state_dict(), os.path.join('model', 'decoder.pkl'))

elif 'translate' in args.mode:
    encoder.load_state_dict(torch.load(os.path.join('model', 'encoder.pkl')))
    decoder.load_state_dict(torch.load(os.path.join('model', 'decoder.pkl')))
    encoder.eval()
    decoder.eval()
    while True:
Example #14
def train_model(cfg):
    tensorboard_path = Path(
        utils.to_absolute_path("tensorboard")) / cfg.checkpoint_dir
    checkpoint_dir = Path(utils.to_absolute_path(cfg.checkpoint_dir))
    writer = SummaryWriter(tensorboard_path)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = Encoder(**cfg.model.encoder)
    decoder = Decoder(**cfg.model.decoder)
    encoder.to(device)
    decoder.to(device)

    optimizer = optim.Adam(chain(encoder.parameters(), decoder.parameters()),
                           lr=cfg.training.optimizer.lr)
    [encoder, decoder], optimizer = amp.initialize([encoder, decoder],
                                                   optimizer,
                                                   opt_level="O1")
    scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=cfg.training.scheduler.milestones,
        gamma=cfg.training.scheduler.gamma)

    if cfg.resume:
        print("Resume checkpoint from: {}:".format(cfg.resume))
        resume_path = utils.to_absolute_path(cfg.resume)
        checkpoint = torch.load(resume_path,
                                map_location=lambda storage, loc: storage)
        encoder.load_state_dict(checkpoint["encoder"])
        decoder.load_state_dict(checkpoint["decoder"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        amp.load_state_dict(checkpoint["amp"])
        scheduler.load_state_dict(checkpoint["scheduler"])
        global_step = checkpoint["step"]
    else:
        global_step = 0

    root_path = Path(utils.to_absolute_path("datasets")) / cfg.dataset.path
    dataset = SpeechDataset(root=root_path,
                            hop_length=cfg.preprocessing.hop_length,
                            sr=cfg.preprocessing.sr,
                            sample_frames=cfg.training.sample_frames)

    dataloader = DataLoader(dataset,
                            batch_size=cfg.training.batch_size,
                            shuffle=True,
                            num_workers=cfg.training.n_workers,
                            pin_memory=True,
                            drop_last=True)

    n_epochs = cfg.training.n_steps // len(dataloader) + 1
    start_epoch = global_step // len(dataloader) + 1

    for epoch in range(start_epoch, n_epochs + 1):
        average_recon_loss = average_vq_loss = average_perplexity = 0

        for i, (audio, mels, speakers) in enumerate(tqdm(dataloader), 1):
            audio, mels, speakers = audio.to(device), mels.to(
                device), speakers.to(device)

            optimizer.zero_grad()

            z, vq_loss, perplexity = encoder(mels)
            output = decoder(audio[:, :-1], z, speakers)
            recon_loss = F.cross_entropy(output.transpose(1, 2), audio[:, 1:])
            loss = recon_loss + vq_loss

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), 1)
            optimizer.step()
            scheduler.step()

            average_recon_loss += (recon_loss.item() - average_recon_loss) / i
            average_vq_loss += (vq_loss.item() - average_vq_loss) / i
            average_perplexity += (perplexity.item() - average_perplexity) / i

            global_step += 1

            if global_step % cfg.training.checkpoint_interval == 0:
                save_checkpoint(encoder, decoder, optimizer, amp, scheduler,
                                global_step, checkpoint_dir)

        writer.add_scalar("recon_loss/train", average_recon_loss, global_step)
        writer.add_scalar("vq_loss/train", average_vq_loss, global_step)
        writer.add_scalar("average_perplexity", average_perplexity,
                          global_step)

        print("epoch:{}, recon loss:{:.2E}, vq loss:{:.2E}, perpexlity:{:.3f}".
              format(epoch, average_recon_loss, average_vq_loss,
                     average_perplexity))
Example #15
def recover_models(device,
                   model="supervised",
                   m=256,
                   n=4,
                   chann_type="AWGN",
                   verbose=False):
    """
    Function to try to recover an already saved system to a channel
    Args:
        device (string): Current device that we are working in
        model (string): Model that wish to be recovered. Options: supervised or alternated
        chann_type (string): Channel type. Currently only AWGN available
        n (int): Length of the encoded messages
        m ((int): Total number of messages that can be encoded
    Returns:
        encoder/tx (Object): Recovered Tx/Encoder model
        decoder/rx (Object): Recovered Rx/Decoder model
    """
    try:
        if model == "supervised":
            enc_filename = "%s/%s_%d_%d_encoder.pth" % (MODELS_FOLDER,
                                                        chann_type, m, n)
            dec_filename = "%s/%s_%d_%d_decoder.pth" % (MODELS_FOLDER,
                                                        chann_type, m, n)

            encoder = Encoder(m=m, n=n)
            encoder.load_state_dict(torch.load(enc_filename))
            if verbose: print('Model loaded from %s.' % enc_filename)
            # Put them in the correct device and eval mode
            encoder.to(device)
            encoder.eval()

            decoder = Decoder(m=m, n=n)
            decoder.load_state_dict(torch.load(dec_filename))
            if verbose: print('Model loaded from %s.' % dec_filename)
            decoder.to(device)
            decoder.eval()

            return encoder, decoder
        else:
            tx_filename = "%s/%s_%d_%d_tx.pth" % (MODELS_FOLDER, chann_type, m,
                                                  n)
            rx_filename = "%s/%s_%d_%d_rx.pth" % (MODELS_FOLDER, chann_type, m,
                                                  n)

            tx = Transmitter(m=m, n=n)
            tx.load_state_dict(torch.load(tx_filename))
            if verbose: print('Model loaded from %s.' % tx_filename)
            # Put them in the correct device and eval mode
            tx.to(device)
            tx.eval()

            rx = Receiver(m=m, n=n)
            rx.load_state_dict(torch.load(rx_filename))
            if verbose: print('Model loaded from %s.' % rx_filename)
            rx.to(device)
            rx.eval()

            return tx, rx
    except Exception as e:
        raise NameError("Something went wrong loading files for system (%s)" %
                        (chann_type)) from e
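
A hedged usage sketch for recover_models, assuming models for an AWGN channel
with m=256 and n=4 were previously saved under MODELS_FOLDER:

import torch

# Assumes ./MODELS_FOLDER already holds AWGN_256_4_encoder.pth / _decoder.pth.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
encoder, decoder = recover_models(device, model="supervised", m=256, n=4,
                                  chann_type="AWGN", verbose=True)
# Both models come back on the requested device and in eval mode.
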
Example #16
File: train.py Project: nnuq/tpu
def main(index, args):
    device = xm.xla_device()

    gen_net = Generator(args).to(device)
    dis_net = Discriminator(args).to(device)
    enc_net = Encoder(args).to(device)

    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            if args.init_type == 'normal':
                nn.init.normal_(m.weight.data, 0.0, 0.02)
            elif args.init_type == 'orth':
                nn.init.orthogonal_(m.weight.data)
            elif args.init_type == 'xavier_uniform':
                nn.init.xavier_uniform_(m.weight.data, 1.)
            else:
                raise NotImplementedError('{} unknown init type'.format(
                    args.init_type))
        elif classname.find('BatchNorm2d') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0.0)

    gen_net.apply(weights_init)
    dis_net.apply(weights_init)
    enc_net.apply(weights_init)

    ae_recon_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_recon_lr, (args.beta1, args.beta2))
    ae_reg_optimizer = torch.optim.Adam(
        itertools.chain(enc_net.parameters(), gen_net.parameters()),
        args.ae_reg_lr, (args.beta1, args.beta2))
    dis_optimizer = torch.optim.Adam(dis_net.parameters(), args.d_lr,
                                     (args.beta1, args.beta2))
    gen_optimizer = torch.optim.Adam(gen_net.parameters(), args.g_lr,
                                     (args.beta1, args.beta2))

    dataset = datasets.ImageDataset(args)
    train_loader = dataset.train
    valid_loader = dataset.valid
    para_loader = pl.ParallelLoader(train_loader, [device])

    fid_stat = str(pathlib.Path(
        __file__).parent.absolute()) + '/fid_stat/fid_stat_cifar10_test.npz'
    if not os.path.exists(fid_stat):
        download_stat_cifar10_test()

    is_best = True
    args.num_epochs = np.ceil(args.num_iter / len(train_loader))

    gen_scheduler = LinearLrDecay(gen_optimizer, args.g_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    dis_scheduler = LinearLrDecay(dis_optimizer, args.d_lr, 0,
                                  args.num_iter / 2, args.num_iter)
    ae_recon_scheduler = LinearLrDecay(ae_recon_optimizer, args.ae_recon_lr, 0,
                                       args.num_iter / 2, args.num_iter)
    ae_reg_scheduler = LinearLrDecay(ae_reg_optimizer, args.ae_reg_lr, 0,
                                     args.num_iter / 2, args.num_iter)

    # initial
    start_epoch = 0
    best_fid = 1e4

    # set writer
    if args.load_path:
        print(f'=> resuming from {args.load_path}')
        assert os.path.exists(args.load_path)
        checkpoint_file = os.path.join(args.load_path, 'Model',
                                       'checkpoint.pth')
        assert os.path.exists(checkpoint_file)
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch']
        best_fid = checkpoint['best_fid']
        gen_net.load_state_dict(checkpoint['gen_state_dict'])
        enc_net.load_state_dict(checkpoint['enc_state_dict'])
        dis_net.load_state_dict(checkpoint['dis_state_dict'])
        gen_optimizer.load_state_dict(checkpoint['gen_optimizer'])
        dis_optimizer.load_state_dict(checkpoint['dis_optimizer'])
        ae_recon_optimizer.load_state_dict(checkpoint['ae_recon_optimizer'])
        ae_reg_optimizer.load_state_dict(checkpoint['ae_reg_optimizer'])
        args.path_helper = checkpoint['path_helper']
        logger = create_logger(args.path_helper['log_path'])
        logger.info(
            f'=> loaded checkpoint {checkpoint_file} (epoch {start_epoch})')
    else:
        # create new log dir
        assert args.exp_name
        logs_dir = str(pathlib.Path(__file__).parent.parent) + '/logs'
        args.path_helper = set_log_dir(logs_dir, args.exp_name)
        logger = create_logger(args.path_helper['log_path'])

    logger.info(args)
    writer_dict = {
        'writer': SummaryWriter(args.path_helper['log_path']),
        'train_global_steps': start_epoch * len(train_loader),
        'valid_global_steps': start_epoch // args.val_freq,
    }

    # train loop
    for epoch in tqdm(range(int(start_epoch), int(args.num_epochs)),
                      desc='total progress'):
        lr_schedulers = (gen_scheduler, dis_scheduler, ae_recon_scheduler,
                         ae_reg_scheduler)
        train(device, args, gen_net, dis_net, enc_net, gen_optimizer,
              dis_optimizer, ae_recon_optimizer, ae_reg_optimizer, para_loader,
              epoch, writer_dict, lr_schedulers)
        if (epoch and epoch % args.val_freq == 0) or epoch == args.num_epochs - 1:
            fid_score = validate(args, fid_stat, gen_net, writer_dict,
                                 valid_loader)
            logger.info(f'FID score: {fid_score} || @ epoch {epoch}.')
            if fid_score < best_fid:
                best_fid = fid_score
                is_best = True
            else:
                is_best = False
        else:
            is_best = False

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'gen_state_dict': gen_net.state_dict(),
                'dis_state_dict': dis_net.state_dict(),
                'enc_state_dict': enc_net.state_dict(),
                'gen_optimizer': gen_optimizer.state_dict(),
                'dis_optimizer': dis_optimizer.state_dict(),
                'ae_recon_optimizer': ae_recon_optimizer.state_dict(),
                'ae_reg_optimizer': ae_reg_optimizer.state_dict(),
                'best_fid': best_fid,
                'path_helper': args.path_helper
            }, is_best, args.path_helper['ckpt_path'])
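
LinearLrDecay is a project-specific scheduler whose source is not shown here.
A plausible minimal sketch, assuming (from the call sites above) that it holds
the learning rate at start_lr until decay_start steps and then decays it
linearly to end_lr at decay_end steps:

class LinearLrDecay:
    # Hypothetical sketch; the real implementation in nnuq/tpu may differ.
    def __init__(self, optimizer, start_lr, end_lr, decay_start, decay_end):
        assert start_lr > end_lr
        self.optimizer = optimizer
        self.delta = (start_lr - end_lr) / (decay_end - decay_start)
        self.start_lr, self.end_lr = start_lr, end_lr
        self.decay_start, self.decay_end = decay_start, decay_end

    def step(self, current_step):
        # Constant lr before decay_start, then a linear ramp down to end_lr.
        if current_step <= self.decay_start:
            lr = self.start_lr
        else:
            lr = max(self.end_lr, self.start_lr -
                     self.delta * (current_step - self.decay_start))
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr
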
Example #17
import torch
import torchvision.transforms as transforms
import tqdm
from torch.utils.data import DataLoader

from config import Config
from models import Encoder, SiameseNetwork

config = Config()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(config)
siamese_network = SiameseNetwork(config)

if config.load_model:
    encoder.load_state_dict(
        torch.load(config.saved_models_folder + '/encoder_epoch25.pth'))
    siamese_network.load_state_dict(
        torch.load(config.saved_models_folder +
                   '/siamese_network_epoch25.pth'))

encoder.to(device)
encoder.train()

siamese_network.to(device)
siamese_network.train()

params = list(encoder.parameters()) + list(siamese_network.parameters())

optimizer = torch.optim.Adam(params, lr=config.lr, betas=(0.9, 0.999))

transform = transforms.Compose([
Example #18
with open(word_map_file, 'r') as j:
    word_map = json.load(j)

decoder = DecoderWithAttention(attention_dim=attention_dim,
                               embed_dim=emb_dim,
                               decoder_dim=decoder_dim,
                               vocab_size=len(word_map),
                               dropout=dropout)

decoder.load_state_dict(
    torch.load('/scratch/scratch2/adsue/pretrained/decoder_dict.pkl'))
decoder = decoder.to(device)
decoder.eval()

encoder = Encoder()
encoder.load_state_dict(
    torch.load('/scratch/scratch2/adsue/pretrained/encoder_dict.pkl'))
encoder = encoder.to(device)
encoder.eval()
##########################################################################################################################

imsize = 256
image_transform = transforms.Compose(
    [transforms.Resize(int(imsize * 76 / 64)),
     transforms.RandomCrop(imsize)])

norm = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

folder = '/scratch/scratch2/adsue/caption_dataset/val2014/'
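
The geometric transforms (image_transform) and the tensor conversion plus
normalization (norm) are kept separate above. A hedged usage sketch applying
them in sequence to a single validation image; the filename is hypothetical:

from PIL import Image

# Hypothetical filename; any image under the val2014 folder would do.
img = Image.open(folder + 'COCO_val2014_000000000042.jpg').convert('RGB')
img = image_transform(img)   # resize, then random-crop to imsize x imsize
img_tensor = norm(img).unsqueeze(0).to(device)  # normalize and add a batch dim
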
Example #19
import torch
import torchvision
import torchvision.transforms as transforms

from config import Config
from models import Encoder, SiameseNetwork

batch_size = 8
threshold = 0.9

config = Config()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(config)
siamese_network = SiameseNetwork(config)

encoder.load_state_dict(
    torch.load(config.saved_models_folder +
               '/encoder_epoch500_loss0.0009.pth'))
encoder.to(device)
encoder.eval()

siamese_network.load_state_dict(
    torch.load(config.saved_models_folder +
               '/siamese_network_epoch500_loss0.0009.pth'))
siamese_network.to(device)
siamese_network.eval()

transform = transforms.Compose(
    [transforms.Grayscale(num_output_channels=1),
     transforms.ToTensor()])
train_data = torchvision.datasets.ImageFolder(config.data_folder,
                                              transform=transform)
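
A hedged inference sketch for the snippet above, assuming SiameseNetwork maps a
pair of encoder embeddings to a similarity score in [0, 1]; the actual
interface lives in models.py, which is not shown:

loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size,
                                     shuffle=False)
with torch.no_grad():
    images, _ = next(iter(loader))
    embeddings = encoder(images.to(device))
    # Illustrative pairing: compare each image with its neighbour in the batch.
    # The two-argument call is an assumption about SiameseNetwork's forward().
    scores = siamese_network(embeddings[:-1], embeddings[1:])
    matches = scores > threshold  # pairs scoring above 0.9 count as matches
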
Example #20
    cifar_10_train_dt = CIFAR10(r'~/.torch',  download=True, transform=ToTensor())
    cifar_10_train_l = DataLoader(cifar_10_train_dt, batch_size=batch_size, shuffle=True, drop_last=True,
                                  pin_memory=torch.cuda.is_available())

    encoder = Encoder().to(device)
    loss_fn = DeepInfoMaxLoss().to(device)
    optim = Adam(encoder.parameters(), lr=1e-4)
    loss_optim = Adam(loss_fn.parameters(), lr=1e-4)

    epoch_restart = 20
    root = Path(f"./Models/run1")

    if epoch_restart is not None and root is not None:
        enc_file = root / Path('encoder' + str(epoch_restart) + '.wgt')
        loss_file = root / Path('loss' + str(epoch_restart) + '.wgt')
        encoder.load_state_dict(torch.load(str(enc_file)))
        loss_fn.load_state_dict(torch.load(str(loss_file)))

    for epoch in range(epoch_restart + 1, 30):
        batch = tqdm(cifar_10_train_l, total=len(cifar_10_train_dt) // batch_size)
        train_loss = []
        for x, target in batch:
            x = x.to(device)

            optim.zero_grad()
            loss_optim.zero_grad()
            y, M = encoder(x)
            # rotate the batch by one position so each feature map is paired
            # with a mismatched image (negative pairs for the contrastive loss)
            M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)  # move the first sample to the end
            loss = loss_fn(y, M, M_prime)
            train_loss.append(loss.item())
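
A standalone sketch of the batch rotation used above to build M_prime, the
mismatched (negative) pairs that the Deep InfoMax loss contrasts against:

import torch

M = torch.arange(4).view(4, 1)  # stand-in for a batch of feature maps
M_prime = torch.cat((M[1:], M[0].unsqueeze(0)), dim=0)
print(M.squeeze().tolist())        # [0, 1, 2, 3]
print(M_prime.squeeze().tolist())  # [1, 2, 3, 0]
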
Example #21
def main(args):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training on {device}")

    if not os.path.exists(args.models_dir):
        os.makedirs(args.models_dir)

    if args.build_vocab:
        print(
            f"Building vocabulary from captions at {args.captions_json} and with count threshold={args.threshold}"
        )
        vocab_object = build_vocab(args.captions_json, args.threshold)
        with open(args.vocab_path, "wb") as vocab_f:
            pickle.dump(vocab_object, vocab_f)
        print(
            f"Saved the vocabulary object to {args.vocab_path}, total size={len(vocab_object)}"
        )
    else:
        with open(args.vocab_path, 'rb') as f:
            vocab_object = pickle.load(f)
        print(
            f"Loaded the vocabulary object from {args.vocab_path}, total size={len(vocab_object)}"
        )

    if args.glove_embed_path is not None:
        with open(args.glove_embed_path, 'rb') as f:
            glove_embeddings = pickle.load(f)
        print(
            f"Loaded the glove embeddings from {args.glove_embed_path}, total size={len(glove_embeddings)}"
        )

        # We are using 300d glove embeddings
        args.embed_size = 300

        weights_matrix = np.zeros((len(vocab_object), args.embed_size))

        for word, index in vocab_object.word2index.items():
            if word in glove_embeddings:
                weights_matrix[index] = glove_embeddings[word]
            else:
                weights_matrix[index] = np.random.normal(
                    scale=0.6, size=(args.embed_size, ))

        weights_matrix = torch.from_numpy(weights_matrix).float().to(device)

    else:
        weights_matrix = None

    img_transforms = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    train_dataset = cocoDataset(args.image_root, args.captions_json,
                                vocab_object, img_transforms)
    train_dataloader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        collate_fn=collate_fn)

    encoder = Encoder(args.resnet_size, (3, 224, 224),
                      args.embed_size).to(device)
    decoder = Decoder(args.rnn_type, weights_matrix, len(vocab_object),
                      args.embed_size, args.hidden_size).to(device)

    encoder_learnable = list(encoder.linear.parameters())
    decoder_learnable = list(decoder.rnn.parameters()) + list(
        decoder.linear.parameters())
    if args.glove_embed_path is None:
        decoder_learnable = decoder_learnable + list(
            decoder.embedding.parameters())

    criterion = nn.CrossEntropyLoss()
    params = encoder_learnable + decoder_learnable
    optimizer = torch.optim.Adam(params, lr=args.learning_rate)

    start_epoch = 0

    if args.ckpt_path is not None:
        model_ckpt = torch.load(args.ckpt_path)
        start_epoch = model_ckpt['epoch'] + 1
        prev_loss = model_ckpt['loss']
        encoder.load_state_dict(model_ckpt['encoder'])
        decoder.load_state_dict(model_ckpt['decoder'])
        optimizer.load_state_dict(model_ckpt['optimizer'])
        print(
            f"Loaded model and optimizer state from {args.ckpt_path}; start epoch at {start_epoch}; prev loss={prev_loss}"
        )

    total_examples = len(train_dataloader)
    for epoch in range(start_epoch, args.num_epochs):
        for i, (images, captions, lengths) in enumerate(train_dataloader):
            images = images.to(device)
            captions = captions.to(device)
            targets = pack_padded_sequence(captions, lengths,
                                           batch_first=True).data

            image_embeddings = encoder(images)
            outputs = decoder(image_embeddings, captions, lengths)

            loss = criterion(outputs, targets)

            decoder.zero_grad()
            encoder.zero_grad()

            loss.backward()
            optimizer.step()

            if i % args.log_interval == 0:
                loss_val = "{:.4f}".format(loss.item())
                perplexity_val = "{:5.4f}".format(np.exp(loss.item()))
                print(
                    f"epoch=[{epoch}/{args.num_epochs}], iteration=[{i}/{total_examples}], loss={loss_val}, perplexity={perplexity_val}"
                )

        torch.save(
            {
                'epoch': epoch,
                'encoder': encoder.state_dict(),
                'decoder': decoder.state_dict(),
                'optimizer': optimizer.state_dict(),
                'loss': loss.item()
            },
            os.path.join(args.models_dir,
                         'model-after-epoch-{}.ckpt'.format(epoch)))
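
The training targets above come from pack_padded_sequence(...).data, which
flattens a padded batch into a single tensor of valid (non-pad) tokens in
time-major order, matching the decoder's packed outputs. A standalone sketch
with made-up token ids:

import torch
from torch.nn.utils.rnn import pack_padded_sequence

# Two captions padded to length 4; lengths must be sorted in decreasing order.
captions = torch.tensor([[1, 2, 3, 4],
                         [5, 6, 0, 0]])
targets = pack_padded_sequence(captions, [4, 2], batch_first=True).data
print(targets.tolist())  # [1, 5, 2, 6, 3, 4] -- the pad tokens are dropped
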