Example #1
def main(args):
	vecs_builder = VecsBuilder(vecs_path='./glove/glove.6B.300d.txt')
	vecs = vecs_builder.get_data()

	train_dataset = Loader(args.max_length,vecs,'train')
	train_loader = DataLoader(train_dataset, batch_size = args.batch_size, num_workers = 5)
	val_dataset = Loader(args.max_length,vecs,'val')
	val_loader = DataLoader(val_dataset, batch_size = args.batch_size)
	model = Classifier(args.embed_dim, args.hidden_dim,args.num_classes,args.num_hidden_layers)

	if torch.cuda.is_available():
		print('CUDA available, moving model to GPU')
		model.cuda()

	best_acc = 0
	automated_log = open('models/automated_log.txt','w+')
	automated_log.write('Epochs'+'\t'+'Train-Loss'+'\t'+'Train-Accuracy'+'\t'+'Validation Loss'+'\t'+'Validation Accuracy\n')

	for epoch in tqdm(range(args.num_epochs)):
		train_loss,train_acc = train(model,train_loader)
		val_loss,val_acc = eval(model,val_loader)
		train_acc = train_acc/train_dataset.num_samples
		val_acc = val_acc/val_dataset.num_samples
		# print('Epoch : ',epoch)
		# print('Train Loss : ',train_loss)
		# print('Train Acc : ',train_acc)
		# print('Validation Loss : ',val_loss)
		# print('Validation Acc : ',val_acc)
		automated_log.write(str(epoch)+'\t'+str(train_loss)+'\t'+str(train_acc)+'\t'+str(val_loss)+'\t'+str(val_acc)+'\n')
		if epoch%10==0:
			model_name = 'models/model_'+str(epoch)+'.pkl'
			torch.save(model.state_dict(),model_name)
		if val_acc>best_acc:
			best_acc = val_acc
			best_model = 'best.pkl'
			torch.save(model.state_dict(),best_model)
			f = open('models/best.txt','w+')
			report = 'Epoch : '+str(epoch)+'\t Validation Accuracy : '+str(best_acc)
			f.write(report)
			f.close()
			print('Best Model Saved with Validation Accuracy :',val_acc)
	automated_log.close()
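
Below is a hedged sketch of the argument namespace that main() above consumes. The attribute names mirror what the function reads; the argparse types and default values are assumptions added for illustration.

import argparse

# Hypothetical CLI wrapper for main(); defaults are assumptions.
parser = argparse.ArgumentParser()
parser.add_argument('--max_length', type=int, default=50)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--embed_dim', type=int, default=300)   # matches glove.6B.300d
parser.add_argument('--hidden_dim', type=int, default=256)
parser.add_argument('--num_classes', type=int, default=2)
parser.add_argument('--num_hidden_layers', type=int, default=1)
parser.add_argument('--num_epochs', type=int, default=100)

if __name__ == '__main__':
    main(parser.parse_args())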
Example #2
class Solver(object):
    def __init__(self, config):

        # Configurations
        self.config = config

        # Build the models
        self.build_models()

    def build_models(self):

        # Models
        self.net = Classifier().to(self.config['device'])

        # Optimizers
        self.optimizer = getattr(torch.optim, self.config['optimizer'])(
            self.net.parameters(),
            lr=self.config['lr'],
        )

        # Criterion
        self.criterion = nn.CrossEntropyLoss(reduction='none')

        # Record
        logging.info(self.net)

    def save_model(self, filename):
        save_path = os.path.join(self.config['save_path'], f'{filename}')
        try:
            logging.info(
                f'Saving best neural network checkpoint to {save_path}')
            torch.save(self.net.state_dict(),
                       save_path,
                       _use_new_zipfile_serialization=False)
        except Exception:
            logging.error(f'Error saving weights to {save_path}')

    def restore_model(self, filename):
        weight_path = os.path.join(self.config['save_path'], f'{filename}')
        try:
            logging.info(f'Loading the trained Extractor from {weight_path}')
            self.net.load_state_dict(
                torch.load(weight_path,
                           map_location=lambda storage, loc: storage))

        except Exception:
            logging.error(f'Error loading weights from {weight_path}')
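
A hedged usage sketch for the Solver above: the config keys mirror the ones the class reads ('device', 'optimizer', 'lr', 'save_path'); the concrete values and file names are assumptions.

config = {
    'device': 'cuda' if torch.cuda.is_available() else 'cpu',
    'optimizer': 'Adam',           # resolved via getattr(torch.optim, ...)
    'lr': 1e-3,
    'save_path': './checkpoints',  # assumed existing directory
}
solver = Solver(config)
solver.save_model('classifier.pt')
solver.restore_model('classifier.pt')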
Example #3
def run_training(opt):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    work_dir, epochs, train_batch, valid_batch, weights = \
        opt.work_dir, opt.epochs, opt.train_bs, opt.valid_bs, opt.weights

    # Directories
    last = os.path.join(work_dir, 'last.pt')
    best = os.path.join(work_dir, 'best.pt')

    # --------------------------------------
    # Setup train and validation set
    # --------------------------------------
    data = pd.read_csv(opt.train_csv)
    images_path = opt.data_dir

    n_classes = 6  # hard-coded number of classes

    data['class'] = data.apply(lambda row: categ[row["class"]], axis=1)
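    # `categ` above is assumed to be a module-level dict mapping class-name strings to integer ids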

    train_loader, val_loader = prepare_dataloader(data,
                                                  opt.fold,
                                                  train_batch,
                                                  valid_batch,
                                                  opt.img_size,
                                                  opt.num_workers,
                                                  data_root=images_path)

    # if not opt.ovr_val:
    #     handwritten_data = pd.read_csv(opt.handwritten_csv)
    #     printed_data = pd.read_csv(opt.printed_csv)
    #     handwritten_data['class'] = handwritten_data.apply(lambda row: categ[row["class"]], axis =1)
    #     printed_data['class'] = printed_data.apply(lambda row: categ[row["class"]], axis =1)
    #     _, handwritten_val_loader = prepare_dataloader(
    #         handwritten_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    #     _, printed_val_loader = prepare_dataloader(
    #         printed_data, opt.fold, train_batch, valid_batch, opt.img_size, opt.num_workers, data_root=images_path)

    # --------------------------------------
    # Models
    # --------------------------------------

    model = Classifier(model_name=opt.model_name,
                       n_classes=n_classes,
                       pretrained=True).to(device)

    if opt.weights is not None:
        cp = torch.load(opt.weights)
        model.load_state_dict(cp['model'])

    # -------------------------------------------
    # Setup optimizer, scheduler, criterion loss
    # -------------------------------------------

    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-6)
    scheduler = CosineAnnealingWarmRestarts(optimizer,
                                            T_0=10,
                                            T_mult=1,
                                            eta_min=1e-6,
                                            last_epoch=-1)
    scaler = GradScaler()

    loss_tr = nn.CrossEntropyLoss().to(device)
    loss_fn = nn.CrossEntropyLoss().to(device)

    # --------------------------------------
    # Setup training
    # --------------------------------------
    if not os.path.exists(work_dir):
        os.mkdir(work_dir)

    best_loss = 1e5
    start_epoch = 0
    best_epoch = 0  # for early stopping

    if opt.resume:
        checkpoint = torch.load(last)

        start_epoch = checkpoint["epoch"]
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint["scheduler"])
        best_loss = checkpoint["best_loss"]

    # --------------------------------------
    # Start training
    # --------------------------------------
    print("[INFO] Start training...")
    for epoch in range(start_epoch, epochs):
        train_one_epoch(epoch,
                        model,
                        loss_tr,
                        optimizer,
                        train_loader,
                        device,
                        scheduler=scheduler,
                        scaler=scaler)
        with torch.no_grad():
            if opt.ovr_val:
                val_loss = valid_one_epoch_overall(epoch,
                                                   model,
                                                   loss_fn,
                                                   val_loader,
                                                   device,
                                                   scheduler=None)
            else:
                val_loss = valid_one_epoch(epoch,
                                           model,
                                           loss_fn,
                                           handwritten_val_loader,
                                           printed_val_loader,
                                           device,
                                           scheduler=None)

            if val_loss < best_loss:
                best_loss = val_loss
                best_epoch = epoch
                torch.save(
                    {
                        'epoch': epoch,
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict(),
                        'best_loss': best_loss
                    }, best)

                print('best model found for epoch {}'.format(epoch + 1))

        torch.save(
            {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
                'best_loss': best_loss
            }, last)

        if epoch - best_epoch > opt.patience:
            print("Early stop achieved at", epoch + 1)
            break

    del model, optimizer, train_loader, val_loader, scheduler, scaler
    torch.cuda.empty_cache()
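
As a hedged aside, the checkpoints written above can be re-opened with the same key names used in the torch.save dict; the path below is an assumption.

ckpt = torch.load('work_dir/last.pt', map_location='cpu')  # path is an assumption
print(ckpt['epoch'], ckpt['best_loss'])
# model/optimizer/scheduler restore from ckpt['model'], ckpt['optimizer'],
# ckpt['scheduler'], exactly as the resume branch in run_training does.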
Example #4
class Agent():
    def __init__(self, state_size, action_size, config):
        self.env_name = config["env_name"]
        self.state_size = state_size
        self.action_size = action_size
        self.seed = config["seed"]
        self.clip = config["clip"]
        self.device = 'cuda'
        print("Clip ", self.clip)
        print("cuda ", torch.cuda.is_available())
        self.double_dqn = config["DDQN"]
        print("Use double dqn", self.double_dqn)
        self.lr_pre = config["lr_pre"]
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        print("self tau", self.tau)
        self.gamma = 0.99
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.fc3 = config["fc3_units"]
        self.qnetwork_local = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.fc1, self.fc2,self.fc3,  self.seed).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.lr)
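        # tau=1 below performs a hard copy of the local weights into the target network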
        self.soft_update(self.qnetwork_local, self.qnetwork_target, 1)
        
        self.q_shift_local = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.q_shift_target = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.optimizer_shift = optim.Adam(self.q_shift_local.parameters(), lr=self.lr)
        self.soft_update(self.q_shift_local, self.q_shift_target, 1)
         
        self.R_local = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3,  self.seed).to(self.device)
        self.R_target = QNetwork(state_size, action_size, self.fc1, self.fc2, self.fc3, self.seed).to(self.device)
        self.optimizer_r = optim.Adam(self.R_local.parameters(), lr=self.lr)
        self.soft_update(self.R_local, self.R_target, 1) 

        self.expert_q = DQNetwork(state_size, action_size, seed=self.seed).to(self.device)
        self.expert_q.load_state_dict(torch.load('checkpoint.pth'))
        self.memory = Memory(action_size, config["buffer_size"], self.batch_size, self.seed, self.device)
        self.t_step = 0
        self.steps = 0
        self.predicter = Classifier(state_size, action_size, self.seed).to(self.device)
        self.optimizer_pre = optim.Adam(self.predicter.parameters(), lr=self.lr_pre)
        pathname = "lr_{}_batch_size_{}_fc1_{}_fc2_{}_fc3_{}_seed_{}".format(self.lr, self.batch_size, self.fc1, self.fc2, self.fc3, self.seed)
        pathname += "_clip_{}".format(config["clip"])
        pathname += "_tau_{}".format(config["tau"])
        now = datetime.now()    
        dt_string = now.strftime("%d_%m_%Y_%H:%M:%S")
        pathname += dt_string
        tensorboard_name = str(config["locexp"]) + '/runs/' + pathname
        self.writer = SummaryWriter(tensorboard_name)
        print("summery writer ", tensorboard_name)
        self.average_prediction = deque(maxlen=100)
        self.average_same_action = deque(maxlen=100)
        self.all_actions = []
        for a in range(self.action_size):
            action = torch.Tensor(1) * 0 +  a
            self.all_actions.append(action.to(self.device))
    
    
    def learn(self, memory):
        logging.debug("--------------------------New episode-----------------------------------------------")
        states, next_states, actions, dones = memory.expert_policy(self.batch_size)
        self.steps += 1
        self.state_action_frq(states, actions)
        self.compute_shift_function(states, next_states, actions, dones)
        for i in range(1):
            for a in range(self.action_size):
                action =  torch.ones([self.batch_size, 1], device= self.device) * a
                self.compute_r_function(states, action)

        self.compute_q_function(states, next_states, actions, dones)
        self.soft_update(self.q_shift_local, self.q_shift_target, self.tau)
        self.soft_update(self.R_local, self.R_target, self.tau)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)
        return
    
    def learn_predicter(self, memory):
        """

        """
        states, next_states, actions, dones = memory.expert_policy(self.batch_size)
        self.state_action_frq(states, actions)
    
    def state_action_frq(self, states, action):
        """ Train classifer to compute state action freq
        """ 
        self.predicter.train()
        output = self.predicter(states, train=True)
        output = output.squeeze(0)
        # logging.debug("out predicter {})".format(output))

        y = action.type(torch.long).squeeze(1)
        #print("y shape", y.shape)
        loss = nn.CrossEntropyLoss()(output, y)
        self.optimizer_pre.zero_grad()
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.predicter.parameters(), 1)
        self.optimizer_pre.step()
        self.writer.add_scalar('Predict_loss', loss, self.steps)
        self.predicter.eval()

    def test_predicter(self, memory):
        """

        """
        self.predicter.eval()
        same_state_predition = 0
        for i in range(memory.idx):
            states = memory.obses[i]
            actions = memory.actions[i]
        
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            actions = torch.as_tensor(actions, device=self.device)
            output = self.predicter(states)   
            output = F.softmax(output, dim=1)
            # create one hot encode y from actions
            y = actions.type(torch.long).item()
            p =torch.argmax(output.data).item()
            if y==p:
                same_state_predition += 1

        
        #self.average_prediction.append(same_state_predition)
        #average_pred = np.mean(self.average_prediction)
        #self.writer.add_scalar('Average prediction acc', average_pred, self.steps)
        #logging.debug("Same prediction {} of 100".format(same_state_predition))
        text = "Same prediction {} of {} ".format(same_state_predition, memory.idx)
        print(text)
        # self.writer.add_scalar('Action prediction acc', same_state_predition, self.steps)
        self.predicter.train()


    def get_action_prob(self, states, actions):
        """
        """
        actions = actions.type(torch.long)
        # check if action prob is zero
        output = self.predicter(states)
        output = F.softmax(output, dim=1)
        # print("get action_prob ", output) 
        # output = output.squeeze(0)
        action_prob = output.gather(1, actions)
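        # add a small epsilon so the log below never receives exactly zero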
        action_prob = action_prob + torch.finfo(torch.float32).eps
        # if there is only one action and its probability is too small, skip the update
        if action_prob.shape[0] == 1:
            if action_prob.cpu().detach().numpy()[0][0] < 1e-4:
                return None
        # logging.debug("action_prob {})".format(action_prob))
        action_prob = torch.log(action_prob)
        action_prob = torch.clamp(action_prob, min= self.clip, max=0)
        return action_prob

    def compute_shift_function(self, states, next_states, actions, dones):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        actions = actions.type(torch.int64)
        with torch.no_grad():
            # Get max predicted Q values (for next states) from target model
            if self.double_dqn:
                qt = self.q_shift_local(next_states)
                max_q, max_actions = qt.max(1)
                Q_targets_next = self.qnetwork_target(next_states).gather(1, max_actions.unsqueeze(1))
            else:
                Q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
            # Compute Q targets for current states
            Q_targets = (self.gamma * Q_targets_next * (dones))

        # Get expected Q values from local model
        Q_expected = self.q_shift_local(states).gather(1, actions)

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer_shift.zero_grad()
        loss.backward()
        self.writer.add_scalar('Shift_loss', loss, self.steps)
        self.optimizer_shift.step()


    def compute_r_function(self, states, actions, debug=False, log=False):
        """

        """
        actions = actions.type(torch.int64)
        
        # sum all other actions
        # print("state shape ", states.shape)
        size = states.shape[0]
        idx = 0
        all_zeros = []
        with torch.no_grad():
            y_shift = self.q_shift_target(states).gather(1, actions)
            log_a = self.get_action_prob(states, actions)
            index_list = index_None_value(log_a)
            # print("is none", index_list)
            if index_list is None:
                return


            y_r_part1 = log_a - y_shift
            y_r_part2 =  torch.empty((size, 1), dtype=torch.float32).to(self.device)
            for a, s in zip(actions, states):
                y_h = 0
                taken_actions = 0
                for b in self.all_actions:
                    b = b.type(torch.int64).unsqueeze(1)
                    n_b = self.get_action_prob(s.unsqueeze(0), b)
                    if torch.eq(a, b) or n_b is None:
                        logging.debug("best action {} ".format(a))
                        logging.debug("n_b action {} ".format(b))
                        logging.debug("n_b {} ".format(n_b))
                        continue
                    taken_actions += 1
                    r_hat = self.R_target(s.unsqueeze(0)).gather(1, b)

                    y_s = self.q_shift_target(s.unsqueeze(0)).gather(1, b)
                    n_b = n_b - y_s

                    y_h += (r_hat - n_b)
                    if debug:
                        print("action", b.item())
                        print("r_pre {:.3f}".format(r_hat.item()))
                        print("n_b {:.3f}".format(n_b.item()))
                if taken_actions == 0:
                    all_zeros.append(idx)
                else:
                    y_r_part2[idx] = (1. / taken_actions)  * y_h
                idx += 1
            #print(y_r_part2, y_r_part1)
            y_r = y_r_part1 + y_r_part2
            #print("_________________")
            #print("r update zeros ", len(all_zeros))
        if len(index_list) > 0:
            print("none list", index_list)
        y = self.R_local(states).gather(1, actions)
        if log:
            text = "Action {:.2f}  y target {:.2f} =  n_a {:.2f} + {:.2f} and pre{:.2f}".format(actions.item(), y_r.item(), y_r_part1.item(), y_r_part2.item(), y.item())
            logging.debug(text)

        if debug:
            print("expet action ", actions.item())
            # print("y r {:.3f}".format(y.item()))
            # print("log a prob {:.3f}".format(log_a.item()))
            # print("n_a {:.3f}".format(y_r_part1.item()))
            print("Correct action p {:.3f} ".format(y.item()))
            print("Correct action target {:.3f} ".format(y_r.item()))
            print("part1 corret action {:.2f} ".format(y_r_part1.item()))
            print("part2 incorret action {:.2f} ".format(y_r_part2.item()))
        
        #print("y", y.shape)
        #print("y_r", y_r.shape)
        
        r_loss = F.mse_loss(y, y_r)
        
        #con = input()
        #sys.exit()
        # Minimize the loss
        self.optimizer_r.zero_grad()
        r_loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.R_local.parameters(), 5)
        self.optimizer_r.step()
        self.writer.add_scalar('Reward_loss', r_loss, self.steps)
        if debug:
            print("after update r pre ", self.R_local(states).gather(1, actions).item())
            print("after update r target ", self.R_target(states).gather(1, actions).item())
        # ------------------- update target network ------------------- #
        #self.soft_update(self.R_local, self.R_target, 5e-3)
        if debug:
            print("after soft upda r target ", self.R_target(states).gather(1, actions).item())
    
    def compute_q_function(self, states, next_states, actions, dones, debug=False, log= False):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        actions = actions.type(torch.int64)
        if debug:
            print("---------------q_update------------------")
            print("expet action ", actions.item())
            print("state ", states)
        with torch.no_grad():
            # Get max predicted Q values (for next states) from target model
            if self.double_dqn:
                qt = self.qnetwork_local(next_states)
                max_q, max_actions = qt.max(1)
                Q_targets_next = self.qnetwork_target(next_states).gather(1, max_actions.unsqueeze(1))
            else:
                Q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
            # Compute Q targets for current states
            rewards = self.R_target(states).gather(1, actions)
            Q_targets = rewards + (self.gamma * Q_targets_next * (dones))
            if debug:
                print("reward  {}".format(rewards.item()))
                print("Q target next {}".format(Q_targets_next.item()))
                print("Q_target {}".format(Q_targets.item()))



        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        if log:
            text = "Action {:.2f}  q target {:.2f} =  r_a {:.2f} + target {:.2f} and pre{:.2f}".format(actions.item(), Q_targets.item(), rewards.item(), Q_targets_next.item(), Q_expected.item())
            logging.debug(text)
        if debug:
            print("q for a {}".format(Q_expected))
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        self.writer.add_scalar('Q_loss', loss, self.steps)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        if debug:
            print("q after update {}".format(self.qnetwork_local(states)))
            print("q loss {}".format(loss.item()))


        # ------------------- update target network ------------------- #



    def dqn_train(self, n_episodes=2000, max_t=1000, eps_start=1.0, eps_end=0.01, eps_decay=0.995):
        env =  gym.make('LunarLander-v2')
        scores = []                        # list containing scores from each episode
        scores_window = deque(maxlen=100)  # last 100 scores
        eps = eps_start
        for i_episode in range(1, n_episodes+1):
            state = env.reset()
            score = 0
            for t in range(max_t):
                self.t_step += 1
                action = self.dqn_act(state, eps)
                next_state, reward, done, _ = env.step(action)
                self.step(state, action, reward, next_state, done)
                state = next_state
                score += reward
                if done:
                    self.test_q()
                    break
            scores_window.append(score)       # save most recent score
            scores.append(score)              # save most recent score
            eps = max(eps_end, eps_decay*eps) # decrease epsilon
            print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)), end="")
            if i_episode % 100 == 0:
                print('\rEpisode {}\tAverage Score: {:.2f}'.format(i_episode, np.mean(scores_window)))
            if np.mean(scores_window)>=200.0:
                print('\nEnvironment solved in {:d} episodes!\tAverage Score: {:.2f}'.format(i_episode-100, np.mean(scores_window)))
                break



    def test_policy(self):
        env =  gym.make('LunarLander-v2')
        logging.debug("new episode")
        average_score = [] 
        average_steps = []
        average_action = []
        for i in range(5):
            state = env.reset()
            score = 0
            same_action = 0
            logging.debug("new episode")
            for t in range(200):
                state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
                q_expert = self.expert_q(state)
                q_values = self.qnetwork_local(state)
                logging.debug("q expert a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}".format(q_expert.data[0][0], q_expert.data[0][1], q_expert.data[0][2], q_expert.data[0][3]))
                logging.debug("q values a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}  )".format(q_values.data[0][0], q_values.data[0][1], q_values.data[0][2], q_values.data[0][3]))
                action = torch.argmax(q_values).item()
                action_e = torch.argmax(q_expert).item()
                if action == action_e:
                    same_action += 1
                next_state, reward, done, _ = env.step(action)
                state = next_state
                score += reward
                if done:
                    average_score.append(score)
                    average_steps.append(t)
                    average_action.append(same_action)
                    break
        mean_steps = np.mean(average_steps)
        mean_score = np.mean(average_score)
        mean_action= np.mean(average_action)
        self.writer.add_scalar('Ave_episode_length', mean_steps, self.steps)
        self.writer.add_scalar('Ave_same_action', mean_action, self.steps)
        self.writer.add_scalar('Ave_score', mean_score, self.steps)


    def step(self, state, action, reward, next_state, done):
        # Save experience in replay memory
        self.memory.add(state, action, reward, next_state, done)

        # Learn every UPDATE_EVERY time steps.
        self.t_step = (self.t_step + 1) % 4
        if self.t_step == 0:
            # If enough samples are available in memory, get random subset and learn
            if len(self.memory) > self.batch_size:
                experiences = self.memory.sample()
                self.update_q(experiences)


    def dqn_act(self, state, eps=0.):
        """Returns actions for given state as per current policy.

        Params
        ======
            state (array_like): current state
            eps (float): epsilon, for epsilon-greedy action selection
        """
        state = torch.from_numpy(state).float().unsqueeze(0).to(self.device)
        self.qnetwork_local.eval()
        with torch.no_grad():
            action_values = self.qnetwork_local(state)
        self.qnetwork_local.train()

        # Epsilon-greedy action selection
        if random.random() > eps:
            return np.argmax(action_values.cpu().data.numpy())
        else:
            return random.choice(np.arange(self.action_size))

    def update_q(self, experiences, debug=False):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        states, actions, rewards, next_states, dones = experiences
        # Get max predicted Q values (for next states) from target model
        with torch.no_grad():
            Q_targets_next = self.qnetwork_target(next_states).max(1)[0].unsqueeze(1)
            # Compute Q targets for current states
            Q_targets = rewards + (self.gamma * Q_targets_next * (1 - dones))

        # Get expected Q values from local model
        Q_expected = self.qnetwork_local(states).gather(1, actions)
        if debug:
            print("----------------------")
            print("----------------------")
            print("Q target", Q_targets)
            print("pre", Q_expected)
            print("all local",self.qnetwork_local(states))

        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # ------------------- update target network ------------------- #
        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)



    def test_q(self):
        experiences = self.memory.test_sample()
        self.update_q(experiences, True)

    def test_q_value(self, memory):
        same_action = 0
        test_elements = memory.idx
        all_diff = 0
        error = True
        self.predicter.eval()
        for i in range(test_elements):
            # print("lop", i)
            states = memory.obses[i]
            next_states = memory.next_obses[i]
            actions = memory.actions[i]
            dones = memory.not_dones[i]
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            next_states = torch.as_tensor(next_states, device=self.device)
            actions = torch.as_tensor(actions, device=self.device)
            dones = torch.as_tensor(dones, device=self.device)
            with torch.no_grad():
                output = self.predicter(states)
                output = F.softmax(output, dim=1)
                q_values = self.qnetwork_local(states)
                expert_values = self.expert_q(states)
                print("q values ", q_values)
                print("ex values  ", expert_values)
                best_action = torch.argmax(q_values).item()
                actions = actions.type(torch.int64)
                q_max = q_values.max(1)
                
                #print("q values", q_values)
                q = q_values[0][actions.item()].item()
                #print("q action", q)
                max_q =  q_max[0].data.item()
                diff = max_q - q
                all_diff += diff
                #print("q best", max_q)
                #print("difference ", diff)
            if  actions.item() != best_action:
                r = self.R_local(states)
                rt = self.R_target(states)
                qt = self.qnetwork_target(states)
                logging.debug("------------------false action --------------------------------")
                logging.debug("expert action  {})".format(actions.item()))
                logging.debug("out predicter a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}  )".format(output.data[0][0], output.data[0][1], output.data[0][2], output.data[0][3]))
                logging.debug("q values a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}  )".format(q_values.data[0][0], q_values.data[0][1], q_values.data[0][2], q_values.data[0][3]))
                logging.debug("q target a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}  )".format(qt.data[0][0], qt.data[0][1], qt.data[0][2], qt.data[0][3]))
                logging.debug("rewards a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}  )".format(r.data[0][0], r.data[0][1], r.data[0][2], r.data[0][3]))
                logging.debug("re target a0: {:.2f} a1: {:.2f} a2: {:.2f} a3: {:.2f}  )".format(rt.data[0][0], rt.data[0][1], rt.data[0][2], rt.data[0][3]))
                """ 
                logging.debug("---------Reward Function------------")
                action = torch.Tensor(1) * 0 +  0
                self.compute_r_function(states, action.unsqueeze(0).to(self.device), log= True)
                action = torch.Tensor(1) * 0 +  1
                self.compute_r_function(states, action.unsqueeze(0).to(self.device), log= True)
                action = torch.Tensor(1) * 0 +  2
                self.compute_r_function(states, action.unsqueeze(0).to(self.device), log= True)
                action = torch.Tensor(1) * 0 +  3
                self.compute_r_function(states, action.unsqueeze(0).to(self.device), log= True)
                logging.debug("------------------Q Function --------------------------------")
                action = torch.Tensor(1) * 0 +  0
                self.compute_q_function(states, next_states.unsqueeze(0), action.unsqueeze(0).to(self.device), dones, log= True)
                action = torch.Tensor(1) * 0 +  1
                self.compute_q_function(states, next_states.unsqueeze(0), action.unsqueeze(0).to(self.device), dones, log= True)
                action = torch.Tensor(1) * 0 +  2
                self.compute_q_function(states, next_states.unsqueeze(0), action.unsqueeze(0).to(self.device), dones, log= True)
                action = torch.Tensor(1) * 0 +  3
                self.compute_q_function(states, next_states.unsqueeze(0), action.unsqueeze(0).to(self.device), dones, log= True)
                """
                

            if  actions.item() == best_action:
                same_action += 1
                continue
                print("-------------------------------------------------------------------------------")
                print("state ", i)
                print("expert ", actions)
                print("q values", q_values.data)
                print("action prob predicter  ", output.data)
                self.compute_r_function(states, actions.unsqueeze(0), True)
                self.compute_q_function(states, next_states.unsqueeze(0), actions.unsqueeze(0), dones, True)
            else:
                if error:
                    continue
                    print("-------------------------------------------------------------------------------")
                    print("expert action ", actions.item())
                    print("best action q ", best_action)
                    print(i)
                    error = False
                continue
                # logging.debug("experte action  {} q fun {}".format(actions.item(), q_values))
                print("-------------------------------------------------------------------------------")
                print("state ", i)
                print("expert ", actions)
                print("q values", q_values.data)
                print("action prob predicter  ", output.data)
                self.compute_r_function(states, actions.unsqueeze(0), True)
                self.compute_q_function(states, next_states.unsqueeze(0), actions.unsqueeze(0), dones, True)


        self.writer.add_scalar('diff', all_diff, self.steps)
        self.average_same_action.append(same_action)
        av_action = np.mean(self.average_same_action)
        self.writer.add_scalar('Same_action', same_action, self.steps)
        print("Same actions {}  of {}".format(same_action, test_elements))
        self.predicter.train()


    def soft_update(self, local_model, target_model, tau=4):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        # print("use tau", tau)
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
    
    def save(self, filename):
        """

        """
        mkdir("", filename)
        torch.save(self.predicter.state_dict(), filename + "_predicter.pth")
        torch.save(self.optimizer_pre.state_dict(), filename + "_predicter_optimizer.pth")
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        """
        torch.save(self.optimizer_q.state_dict(), filename + "_q_net_optimizer.pth")
        torch.save(self.q_shift_local.state_dict(), filename + "_q_shift_net.pth")
        torch.save(self.optimizer_q_shift.state_dict(), filename + "_q_shift_net_optimizer.pth")
        """
        print("save models to {}".format(filename))
    
    def load(self, filename):
        self.predicter.load_state_dict(torch.load(filename + "_predicter.pth"))
        self.optimizer_pre.load_state_dict(torch.load(filename + "_predicter_optimizer.pth"))
        print("Load models to {}".format(filename))
Example #5
                        default=28,
                        help='Image size (default to be squared images)')
    args = parser.parse_args()

    # load data
    transform = transforms.Compose(
        [transforms.CenterCrop(args.image_size),
         transforms.ToTensor()])
    trainset = datasets.MNIST('./data',
                              train=True,
                              download=True,
                              transform=transform)
    dataloader = torch.utils.data.DataLoader(trainset,
                                             batch_size=args.batch_size,
                                             shuffle=True,
                                             num_workers=1)

    model = Classifier(args)
    crit = nn.CrossEntropyLoss()
    opt = optim.Adam(model.parameters(), lr=args.learning_rate)
    writer = SummaryWriter(args.tensorboard_path)
    print(f"Training for {args.epochs} epochs")
    step = 0
    for epoch in range(args.epochs):
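        # train() is assumed to return the per-batch losses and the updated global step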
        loss, step = train(args, dataloader, model, opt, crit, step)
        avg_loss = np.average(loss)
        print(f"Epoch {epoch} loss = {avg_loss}")
        writer.add_scalar('class_loss', avg_loss, step)

    torch.save(model.state_dict(), args.save_path)
Example #6
    epoch_loss = 0
    for i, batch in enumerate(dataloader):
        video_data, audio_data, labels = batch
        video_data = video_data.to(device)
        audio_data = audio_data.to(device)
        labels = labels.to(device)
        video_data = video_data.permute(1, 0, 2)
        audio_data = audio_data.permute(1, 0, 2)
        optimizer.zero_grad()
        output = model(video_data, audio_data)
        loss = criterion(output, labels)
        epoch_loss += loss.item()
        loss.backward()
        optimizer.step()
        print('.', end='', flush=True)
    epoch_end_time = datetime.datetime.now()
    epoch_exec_time = epoch_end_time - epoch_start_time
    print(
        f'\nepoch: {epoch}, loss: {epoch_loss/(epoch_size/batch_size)}, executed in: {str(epoch_exec_time)}'
    )
end_time = datetime.datetime.now()
print(f"end time: {str(end_time)}")
exec_time = end_time - start_time
print(f"executed in: {str(exec_time)}")
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0) / 1024**3, 1),
          'GB')
    print('Cached:   ', round(torch.cuda.memory_cached(0) / 1024**3, 1), 'GB')
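    # torch.cuda.memory_cached() is deprecated in newer PyTorch; torch.cuda.memory_reserved() is the replacement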
torch.save(model.state_dict(), f'classifier2_{end_time.isoformat()}.pt')
Example #7
            sensitivity = TP / (TP + FN)
            specificity = TN / (TN + FP)
            precision = TP / (TP + FP)
            print(
                '[%s] Avg loss: %.5f, accu: %.5f, precision: %.5f, sensitivity: %.5f, specificity: %.5f'
                % (stage_info[stage], avg_loss / counter, avg_accu / counter,
                   precision, sensitivity, specificity))

            if stage == 'valid':
                with open(
                        os.path.join(
                            args.output_dir,
                            'epoch-%03d-auc-%.3f.json' % (epoch, precision)),
                        'w') as f:
                    json.dump(outputs, f, indent=2)

            if stage == 'valid' and precision > best:
                best = precision
                torch.save(model.state_dict(), 'best.ckpt')
                print('Save best model with auc score:', best)

        except KeyboardInterrupt:
            torch.save(model.state_dict(), args.pause_ckpt)
            print('save temporary model into %s' % args.pause_ckpt)
            terminated = True
            break

print('Total:', time.time() - start)
print('Finished Training')
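
For context, here is a minimal, hedged sketch of how the TP/TN/FP/FN counts used above are typically accumulated for a binary classifier; the tensor names and shapes are assumptions, not taken from the original script.

import torch

logits = torch.randn(8, 2)           # assumed model outputs (batch of 8, 2 classes)
labels = torch.randint(0, 2, (8,))   # assumed ground-truth labels
preds = logits.argmax(dim=1).bool()
truth = labels.bool()
TP = (preds & truth).sum().item()
TN = (~preds & ~truth).sum().item()
FP = (preds & ~truth).sum().item()
FN = (~preds & truth).sum().item()
sensitivity = TP / (TP + FN) if TP + FN else 0.0
specificity = TN / (TN + FP) if TN + FP else 0.0
precision = TP / (TP + FP) if TP + FP else 0.0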
Example #8
class Agent():
    def __init__(self, state_size, action_size, config):
        self.seed = config["seed"]
        torch.manual_seed(self.seed)
        np.random.seed(seed=self.seed)
        random.seed(self.seed)
        self.env = gym.make(config["env_name"])
        self.env.seed(self.seed)
        self.state_size = state_size
        self.action_size = action_size
        self.clip = config["clip"]
        self.device = 'cuda'
        print("Clip ", self.clip)
        print("cuda ", torch.cuda.is_available())
        self.double_dqn = config["DDQN"]
        print("Use double dqn", self.double_dqn)
        self.lr_pre = config["lr_pre"]
        self.batch_size = config["batch_size"]
        self.lr = config["lr"]
        self.tau = config["tau"]
        print("self tau", self.tau)
        self.gamma = 0.99
        self.target_entropy = -torch.prod(torch.Tensor(action_size).to(self.device)).item()
        self.fc1 = config["fc1_units"]
        self.fc2 = config["fc2_units"]
        self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
        self.alpha = self.log_alpha.exp()
        self.alpha_optim = optim.Adam([self.log_alpha], lr=config["lr_alpha"])
        self.policy = SACActor(state_size, action_size, self.seed).to(self.device)
        self.policy_optim = optim.Adam(self.policy.parameters(), lr=config["lr_policy"])
        
        self.qnetwork_local = QNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.qnetwork_target = QNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.optimizer = optim.Adam(self.qnetwork_local.parameters(), lr=self.lr)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, 1)
        
        self.q_shift_local = SQNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.q_shift_target = SQNetwork(state_size, action_size,self.seed, self.fc1, self.fc2).to(self.device)
        self.optimizer_shift = optim.Adam(self.q_shift_local.parameters(), lr=self.lr)
        self.soft_update(self.q_shift_local, self.q_shift_target, 1)
         
        self.R_local = SQNetwork(state_size, action_size, self.seed,  self.fc1, self.fc2).to(self.device)
        self.R_target = SQNetwork(state_size, action_size, self.seed, self.fc1, self.fc2).to(self.device)
        self.optimizer_r = optim.Adam(self.R_local.parameters(), lr=self.lr)
        self.soft_update(self.R_local, self.R_target, 1) 

        self.steps = 0
        self.predicter = Classifier(state_size, action_size, self.seed, 256, 256).to(self.device)
        self.optimizer_pre = optim.Adam(self.predicter.parameters(), lr=self.lr_pre)
        pathname = "lr_{}_batch_size_{}_fc1_{}_fc2_{}_seed_{}".format(self.lr, self.batch_size, self.fc1, self.fc2, self.seed)
        pathname += "_clip_{}".format(config["clip"])
        pathname += "_tau_{}".format(config["tau"])
        now = datetime.now()    
        dt_string = now.strftime("%d_%m_%Y_%H:%M:%S")
        pathname += dt_string
        tensorboard_name = str(config["locexp"]) + '/runs/' + pathname
        self.vid_path = str(config["locexp"]) + '/vid'
        self.writer = SummaryWriter(tensorboard_name)
        print("summery writer ", tensorboard_name)
        self.average_prediction = deque(maxlen=100)
        self.average_same_action = deque(maxlen=100)
        self.all_actions = []
        for a in range(self.action_size):
            action = torch.Tensor(1) * 0 +  a
            self.all_actions.append(action.to(self.device))
    
    def learn(self, memory_ex, memory_all):
        self.steps += 1
        logging.debug("--------------------------New update-----------------------------------------------")
        states, next_states, actions, dones = memory_ex.expert_policy(self.batch_size)
        self.state_action_frq(states, actions)
        states, next_states, actions, dones = memory_all.expert_policy(self.batch_size)
        self.compute_shift_function(states, next_states, actions, dones)
        self.compute_r_function(states, actions)
        self.compute_q_function(states, next_states, actions, dones)
        self.soft_update(self.R_local, self.R_target, self.tau)
        self.soft_update(self.q_shift_local, self.q_shift_target, self.tau)
        self.soft_update(self.qnetwork_local, self.qnetwork_target, self.tau)         
        return
        
        
    
    def compute_q_function(self, states, next_states, actions, dones):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        qf1, qf2 = self.qnetwork_local(states)
        q_value1 = qf1.gather(1, actions)
        q_value2 = qf2.gather(1, actions)
        
        with torch.no_grad():
            q1_target, q2_target = self.qnetwork_target(next_states)
            min_q_target = torch.min(q1_target, q2_target)
            next_action_prob, next_action_log_prob = self.policy(next_states)
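            # expected soft state value under the current policy: sum_a pi(a|s') * (min Q(s',a) - alpha * log pi(a|s'))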
            next_q_target = (next_action_prob * (min_q_target - self.alpha * next_action_log_prob)).sum(dim=1, keepdim=True)
            rewards = self.R_target(states).detach().gather(1, actions.detach()).squeeze(0)
            Q_targets = rewards + ((1 - dones) * self.gamma * next_q_target)
        
        loss = F.mse_loss(q_value2, Q_targets.detach()) + F.mse_loss(q_value1, Q_targets.detach())
        
        # Get max predicted Q values (for next states) from target model
        
        self.writer.add_scalar('loss/q_loss', loss, self.steps)
        # Minimize the loss
        self.optimizer.zero_grad()
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.qnetwork_local.parameters(), 1)
        self.optimizer.step()

        # --------------------------update-policy--------------------------------------------------------
        action_prob, log_action_prob = self.policy(states)
        with torch.no_grad():
            q_pi1, q_pi2 = self.qnetwork_local(states)
            min_q_values = torch.min(q_pi1, q_pi2)
        #policy_loss = (action_prob *  ((self.alpha * log_action_prob) - min_q_values).detach()).sum(dim=1).mean()
        policy_loss = (action_prob *  ((self.alpha * log_action_prob) - min_q_values)).sum(dim=1).mean()
        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()
        self.writer.add_scalar('loss/policy', policy_loss, self.steps)

        # --------------------------update-alpha--------------------------------------------------------
        alpha_loss =(action_prob.detach() *  (-self.log_alpha * (log_action_prob + self.target_entropy).detach())).sum(dim=1).mean()
        self.alpha_optim.zero_grad()
        alpha_loss.backward()
        self.alpha_optim.step()
        self.writer.add_scalar('loss/alpha', alpha_loss, self.steps)
        self.alpha = self.log_alpha.exp()
        

    def compute_shift_function(self, states, next_states, actions, dones):
        """Update value parameters using given batch of experience tuples.
        Params
        ======
            experiences (Tuple[torch.Tensor]): tuple of (s, a, r, s', done) tuples
            gamma (float): discount factor
        """
        actions = actions.type(torch.int64)
        with torch.no_grad():
            # Get max predicted Q values (for next states) from target model
            #if self.double_dqn:
            qt1, qt2 = self.qnetwork_local(next_states)
            q_min = torch.min(qt1, qt2)
            max_q, max_actions = q_min.max(1)
            Q_targets_next1, Q_targets_next2 = self.qnetwork_target(next_states)
            Q_targets_next = torch.min(Q_targets_next1, Q_targets_next2)
            Q_targets_next = Q_targets_next.gather(1, max_actions.type(torch.int64).unsqueeze(1))
            #else:
            #Q_targets_next = self.qnetwork_target(next_states).detach().max(1)[0].unsqueeze(1)
            # Compute Q targets for current states
            Q_targets = self.gamma * Q_targets_next * (dones)

        # Get expected Q values from local model
        Q_expected = self.q_shift_local(states).gather(1, actions)
        # Compute loss
        loss = F.mse_loss(Q_expected, Q_targets.detach())
        # Minimize the loss
        self.optimizer_shift.zero_grad()
        loss.backward()
        self.writer.add_scalar('Shift_loss', loss, self.steps)
        self.optimizer_shift.step()
    
    
    def compute_r_function(self, states, actions, debug=False, log=False):
        actions = actions.type(torch.int64)
        # sum all other actions
        # print("state shape ", states.shape)
        size = states.shape[0]
        idx = 0
        all_zeros = [1 for i in range(actions.shape[0])]
        zeros = False
        y_shift = self.q_shift_target(states).gather(1, actions).detach()
        log_a = self.get_action_prob(states, actions).detach()
        y_r_part1 = log_a - y_shift
        y_r_part2 = torch.empty((size, 1), dtype=torch.float32).to(self.device)
        for a, s in zip(actions, states):
            y_h = 0
            taken_actions = 0
            for b in self.all_actions:
                b = b.type(torch.int64).unsqueeze(1)
                n_b = self.get_action_prob(s.unsqueeze(0), b)
                if torch.eq(a, b) or n_b is None:
                    continue
                taken_actions += 1
                y_s = self.q_shift_target(s.unsqueeze(0)).detach().gather(1, b).item()
                n_b = n_b.data.item() - y_s
                r_hat = self.R_target(s.unsqueeze(0)).gather(1, b).item()
                y_h += (r_hat - n_b)
                if log:
                    text = "a {} r _hat {:.2f} - n_b  {:.2f} | sh {:.2f} ".format(b.item(), r_hat, n_b, y_s)
                    logging.debug(text)
            if taken_actions == 0:
                all_zeros[idx] = 0
                zeros = True
                y_r_part2[idx] = 0.0
            else:
                y_r_part2[idx] = (1. / taken_actions) * y_h
            idx += 1
            y_r = y_r_part1 + y_r_part2
        # if some tuples had no valid counter-actions (all zeros), mask them out of states, actions and targets
        if zeros:
            #print(all_zeros)
            #print(states)
            #print(actions)
            mask = torch.BoolTensor(all_zeros)
            states = states[mask]
            actions = actions[mask]
            y_r = y_r[mask]

        y = self.R_local(states).gather(1, actions)
        if log:
            text = "Action {:.2f} r target {:.2f} =  n_a {:.2f} + n_b {:.2f}  y {:.2f}".format(actions[0].item(), y_r[0].item(), y_r_part1[0].item(), y_r_part2[0].item(), y[0].item()) 
            logging.debug(text)


        r_loss = F.mse_loss(y, y_r.detach())

        # sys.exit()
        # Minimize the loss
        self.optimizer_r.zero_grad()
        r_loss.backward()
        # torch.nn.utils.clip_grad_norm_(self.R_local.parameters(), 5)
        self.optimizer_r.step()
        self.writer.add_scalar('Reward_loss', r_loss, self.steps)

    
    
    
    def get_action_prob(self, states, actions):
        """
        """
        actions = actions.type(torch.long)
        # check if action prob is zero
        output = self.predicter(states)
        output = F.softmax(output, dim=1)
        action_prob = output.gather(1, actions)
        action_prob = action_prob + torch.finfo(torch.float32).eps
        # if there is only one action and its probability is too small, skip the update
        if action_prob.shape[0] == 1:
            if action_prob.cpu().detach().numpy()[0][0] < 1e-4:
                return None
        action_prob = torch.log(action_prob)
        action_prob = torch.clamp(action_prob, min= self.clip, max=0)
        return action_prob


    def state_action_frq(self, states, action):
        """ Train classifer to compute state action freq
        """
        self.predicter.train()
        output = self.predicter(states, train=True)
        output = output.squeeze(0)
        # logging.debug("out predicter {})".format(output))

        y = action.type(torch.long).squeeze(1)
        #print("y shape", y.shape)
        loss = nn.CrossEntropyLoss()(output, y)
        self.optimizer_pre.zero_grad()
        loss.backward()
        #torch.nn.utils.clip_grad_norm_(self.predicter.parameters(), 1)
        self.optimizer_pre.step()
        self.writer.add_scalar('Predict_loss', loss, self.steps)
        self.predicter.eval()


    def test_predicter(self, memory):
        """

        """
        self.predicter.eval()
        same_state_predition = 0
        for i in range(memory.idx):
            states = memory.obses[i]
            actions = memory.actions[i]
        
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            actions = torch.as_tensor(actions, device=self.device)
            output = self.predicter(states)   
            output = F.softmax(output, dim=1)
            #print("state 0", output.data)
            # create one hot encode y from actions
            y = actions.type(torch.long).item()
            p = torch.argmax(output.data).item()
            #print("a {}  p {}".format(y, p))
            text = "r  {}".format(self.R_local(states.detach()).detach()) 
            #print(text)
            if y==p:
                same_state_predition += 1
        text = "Same prediction {} of {} ".format(same_state_predition, memory.idx)
        print(text)
        logging.debug(text)




    def soft_update(self, local_model, target_model, tau=4):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target
        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter
        """
        # print("use tau", tau)
        for target_param, local_param in zip(target_model.parameters(), local_model.parameters()):
            target_param.data.copy_(tau * local_param.data + (1.0 - tau) * target_param.data)
    
    def load(self, filename):
        self.predicter.load_state_dict(torch.load(filename + "_predicter.pth"))
        self.optimizer_pre.load_state_dict(torch.load(filename + "_predicter_optimizer.pth"))
        self.R_local.load_state_dict(torch.load(filename + "_r_net.pth"))
        self.qnetwork_local.load_state_dict(torch.load(filename + "_q_net.pth"))
        
        print("Load models to {}".format(filename))


    def save(self, filename):
        """
        """
        mkdir("", filename)
        torch.save(self.predicter.state_dict(), filename + "_predicter.pth")
        torch.save(self.optimizer_pre.state_dict(), filename + "_predicter_optimizer.pth")
        torch.save(self.qnetwork_local.state_dict(), filename + "_q_net.pth")
        torch.save(self.optimizer.state_dict(), filename + "_q_net_optimizer.pth")
        torch.save(self.R_local.state_dict(), filename + "_r_net.pth")
        torch.save(self.q_shift_local.state_dict(), filename + "_q_shift_net.pth")

        print("save models to {}".format(filename))

    def test_q_value(self, memory):
        test_elements = memory.idx
        all_diff = 0
        error = True
        used_elements_r = 0
        used_elements_q = 0
        r_error = 0
        q_error = 0
        for i in range(test_elements):
            states = memory.obses[i]
            actions = memory.actions[i]
            states = torch.as_tensor(states, device=self.device).unsqueeze(0)
            actions = torch.as_tensor(actions, device=self.device)
            one_hot = torch.Tensor([0 for i in range(self.action_size)], device="cpu")
            one_hot[actions.item()] = 1
            with torch.no_grad():
                r_values = self.R_local(states)
                q_values1, q_values2 = self.qnetwork_local(states)
                q_values = torch.min(q_values1, q_values2)
                soft_r = F.softmax(r_values, dim=1).to("cpu")
                soft_q = F.softmax(q_values, dim=1).to("cpu")
                actions = actions.type(torch.int64)
                kl_q =  F.kl_div(soft_q.log(), one_hot, None, None, 'sum')
                kl_r =  F.kl_div(soft_r.log(), one_hot, None, None, 'sum')
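                # F.kl_div expects log-probabilities as its first argument and
                # probabilities as its second, hence the .log() on the softmax outputs.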
                if kl_r == float("inf"):
                    pass
                else:
                    r_error += kl_r
                    used_elements_r += 1
                if kl_q == float("inf"):
                    pass
                else:
                    q_error += kl_q
                    used_elements_q += 1
                    
        average_q_kl = q_error / used_elements_q
        average_r_kl = r_error / used_elements_r
        text = "Kl div of Reward {} of {} elements".format(average_q_kl, used_elements_r)
        print(text)
        text = "Kl div of Q_values {} of {} elements".format(average_r_kl, used_elements_q)
        print(text)
        self.writer.add_scalar('KL_reward', average_r_kl, self.steps)
        self.writer.add_scalar('KL_q_values', average_q_kl, self.steps)


    def act(self, state):
        with torch.no_grad():
            state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
            action_prob, _ = self.policy(state)
            action = torch.argmax(action_prob)
            action = action.cpu().numpy()
        return action 


    def eval_policy(self, record=False, eval_episodes=4):
        if record:
            env = wrappers.Monitor(self.env, str(self.vid_path) + "/{}".format(self.steps), video_callable=lambda episode_id: True, force=True)
        else:
            env = self.env

        average_reward = 0
        scores_window = deque(maxlen=100)
        s = 0
        for i_episode in range(eval_episodes):
            episode_reward = 0
            state = env.reset()
            while True:
                s += 1
                action = self.act(state)
                state, reward, done, _ = env.step(action)
                episode_reward += reward
                if done:
                    break
            scores_window.append(episode_reward)
        if record:
            return 
        average_reward = np.mean(scores_window)
        print("Eval Episode {}  average Reward {} ".format(eval_episodes, average_reward))
        self.writer.add_scalar('Eval_reward', average_reward, self.steps)
Ejemplo n.º 9
0
def main():
  args = parser.parse_args()

  # model
  model = Classifier(args.channels)
  optimizer = optim.SGD(
    model.parameters(), lr=0.05, momentum=0.9, weight_decay=0.0001, nesterov=True)
  scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epoch)

  if args.gpu is not None:
    model.cuda(args.gpu)

  # dataset
  raw_loader = torch.utils.data.DataLoader(
    Dataset(os.path.join(DATA_DIR, 'raw')),
    args.batch // 2, shuffle=True, drop_last=True)
  noised_loader = torch.utils.data.DataLoader(
    Dataset(os.path.join(DATA_DIR, 'noised_tgt')),
    args.batch // 2, shuffle=True, drop_last=True)

  # train
  for epoch in range(args.epoch):
    loss = 0
    accuracy = 0
    count = 0

    for x0, x1 in zip(noised_loader, raw_loader):
      if args.gpu is not None:
        x0 = x0.cuda(args.gpu)
        x1 = x1.cuda(args.gpu)

      # train
      model.train()

      x = torch.cat((x0, x1), dim=0)  # @UndefinedVariable
      t = torch.zeros((x.shape[0], 2), device=x.device).float()  # @UndefinedVariable

      t[:x0.shape[0], 0] = 1
      t[x0.shape[0]:, 1] = 1

      x, t = mixup(x, t)
      y = model(x)
      e = (-1 * nn.functional.log_softmax(y, dim=1) * t).sum(dim=1).mean()

      optimizer.zero_grad()
      e.backward()
      optimizer.step()

      # validate
      model.eval()

      with torch.no_grad():
        y0 = (model(x0).max(dim=1)[1] == 0).float()
        y1 = (model(x1).max(dim=1)[1] == 1).float()

      a = torch.cat((y0, y1), dim=0).mean()  # @UndefinedVariable

      loss += float(e) * len(x)
      accuracy += float(a) * len(x)
      count += len(x)

    print('[{}] lr={:.7f}, loss={:.4f}, accuracy={:.4f}'.format(
      epoch, float(optimizer.param_groups[0]['lr']), loss / count, accuracy / count),
      flush=True)

    scheduler.step()

    snapshot = {'channels': args.channels, 'model': model.state_dict()}
    torch.save(snapshot, '{}.tmp'.format(args.file))
    os.rename('{}.tmp'.format(args.file), args.file)
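
# The mixup helper called above is not included in this snippet; a minimal
# sketch of such a function, assuming numpy is imported as np and that x and t
# are batch-aligned tensors (hypothetical, not the original implementation):
def mixup(x, t, alpha=0.2):
  # Sample a Beta-distributed mixing weight and blend each sample with a
  # randomly chosen partner from the same batch.
  lam = np.random.beta(alpha, alpha)
  index = torch.randperm(x.shape[0], device=x.device)
  return lam * x + (1 - lam) * x[index], lam * t + (1 - lam) * t[index]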
Ejemplo n.º 10
0
class MLCModel:
    """Summary
    Attributes:
        cfg (TYPE): Description
        criterion (TYPE): Description
        device (TYPE): cpu or gpu
        hparams (TYPE): hyper parameters from parser
        labels (TYPE): list of the diseases, see init_labels()
        model (TYPE): feature extraction backbone with classifier, see cfg.json
        names (TYPE): list of image filenames returned by the dataloader
        num_tasks (TYPE): 5 or 14, number of diseases
    """
    def __init__(self, hparams):
        """Summary

        Args:
            hparams (TYPE): hyper parameters from parser
        """
        super(MLCModel, self).__init__()
        self.hparams = hparams
        self.device = torch.device("cuda:{}".format(hparams.gpus) if torch.
                                   cuda.is_available() else "cpu")

        with open(self.hparams.json_path, 'r') as f:
            self.cfg = edict(json.load(f))
            hparams_dict = vars(self.hparams)
            self.cfg['hparams'] = hparams_dict
            if self.hparams.verbose is True:
                print(json.dumps(self.cfg, indent=4))

        if self.cfg.criterion in ['bce', 'focal', 'sce', 'bce_v2', 'bfocal']:
            self.criterion = init_loss_func(self.cfg.criterion,
                                            device=self.device)
        elif self.cfg.criterion == 'class_balance':
            samples_per_cls = list(
                map(int, self.cfg.samples_per_cls.split(',')))
            self.criterion = init_loss_func(self.cfg.criterion,
                                            samples_per_cls=samples_per_cls,
                                            loss_type=self.cfg.loss_type)
        else:
            self.criterion = init_loss_func(self.cfg.criterion)

        self.labels = init_labels(name=self.hparams.data_name)
        if self.cfg.extract_fields is None:
            self.cfg.extract_fields = ','.join(
                [str(idx) for idx in range(len(self.labels))])
        else:
            assert isinstance(self.cfg.extract_fields,
                              str), "extract_fields must be string!"

        self.model = Classifier(self.cfg, self.hparams)
        self.state_dict = None
        # Load cross-model from other configuration
        if self.hparams.load is not None and len(self.hparams.load) > 0:
            if not os.path.exists(hparams.load):
                raise ValueError('{} does not exists!'.format(hparams.load))
            state_dict = load_state_dict(self.hparams.load, self.model,
                                         self.device)
            self.state_dict = state_dict

        # DataParallel model
        if torch.cuda.device_count() > 1 and self.hparams.gpus == 0:
            self.model = nn.DataParallel(self.model)

        self.model.to(device=self.device)
        self.num_tasks = list(map(int, self.cfg.extract_fields.split(',')))
        self.names = list()
        self.optimizer, self.scheduler = self.configure_optimizers()
        self.train_loader = self.train_dataloader()
        self.valid_loader = self.val_dataloader()
        self.test_loader = self.test_dataloader()

    def forward(self, x):
        """Summary

        Args:
            x (TYPE): image

        Returns:
            TYPE: logits predicted by the model for the input batch
        """
        return self.model(x)

    def train(self):
        epoch_start = 0

        summary_train = {
            'epoch': 0,
            'step': 0,
            'total_step': len(self.train_loader)
        }
        summary_dev = {'loss': float('inf'), 'score': 0.0}
        best_dict = {
            "score_dev_best": 0.0,
            "loss_dev_best": float('inf'),
            "score_top_k": [0.0],
            "loss_top_k": [0.0],
            "score_curr_idx": 0,
            "loss_curr_idx": 0
        }

        if self.state_dict is not None:
            summary_train = {
                'epoch': self.state_dict['epoch'],
                'step': self.state_dict['step'],
                'total_step': len(self.train_loader)
            }
            best_dict['score_dev_best'] = self.state_dict['score_dev_best']
            best_dict['loss_dev_best'] = self.state_dict['loss_dev_best']
            epoch_start = self.state_dict['epoch']

        for epoch in range(epoch_start, self.hparams.epochs):
            lr = self.create_scheduler(start_epoch=summary_train['epoch'])
            for param_group in self.optimizer.param_groups:
                param_group['lr'] = lr

            logging.info('Learning rate in epoch {}: {}'.format(
                epoch + 1, self.optimizer.param_groups[0]['lr']))
            print('Learning rate in epoch {}: {}'.format(
                epoch + 1, self.optimizer.param_groups[0]['lr']))

            summary_train, best_dict = self.training_step(
                summary_train, summary_dev, best_dict)

            self.validation_end(summary_dev, summary_train, best_dict)

            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'score_dev_best': best_dict['score_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': self.model.state_dict()
                },
                os.path.join(self.hparams.save_path,
                             '{}_model.pth'.format(summary_train['epoch'] -
                                                   1)))

        logging.info('Training finished, model saved')
        print('Training finished, model saved')

    # def training_step(self, batch, batch_nb):
    def training_step(self, summary_train, summary_dev, best_dict):
        """Summary
        Extract the batch of datapoints and return the predicted logits
        Args:
            summary_train:
            summary_dev:
            best_dict:
        Returns:
            TYPE: Description
        """
        losses = AverageMeter()
        torch.set_grad_enabled(True)
        self.model.train()
        time_now = time.time()

        for i, (inputs, target, _) in enumerate(self.train_loader):
            if isinstance(inputs, tuple):
                inputs = tuple([
                    e.to(self.device) if type(e) == torch.Tensor else e
                    for e in inputs
                ])
            else:
                inputs = inputs.to(self.device)
            target = target.to(self.device)
            self.optimizer.zero_grad()

            if self.cfg.no_jsd:
                if self.cfg.n_crops:
                    bs, n_crops, c, h, w = inputs.size()
                    inputs = inputs.view(-1, c, h, w)

                    if len(self.hparams.mixtype) > 0:
                        if self.hparams.multi_cls:
                            target = target.view(target.size()[0], -1)
                            inputs, targets_a, targets_b, lam = self.mix_data(
                                inputs,
                                target.repeat(1, n_crops).view(-1),
                                self.device, self.hparams.alpha)
                        else:
                            inputs, targets_a, targets_b, lam = self.mix_data(
                                inputs,
                                target.repeat(1, n_crops).view(
                                    -1, len(self.num_tasks)), self.device,
                                self.hparams.alpha)

                    logits = self.forward(inputs)
                    if len(self.hparams.mixtype) > 0:
                        loss_func = self.mixup_criterion(
                            targets_a, targets_b, lam)
                        loss = loss_func(self.criterion, logits)
                    else:
                        if self.hparams.multi_cls:
                            target = target.view(target.size()[0], -1)
                            loss = self.criterion(
                                logits,
                                target.repeat(1, n_crops).view(-1))
                        else:
                            loss = self.criterion(
                                logits,
                                target.repeat(1, n_crops).view(
                                    -1, len(self.num_tasks)))
                else:
                    if len(self.hparams.mixtype) > 0:
                        inputs, targets_a, targets_b, lam = self.mix_data(
                            inputs, target, self.device, self.hparams.alpha)

                    logits = self.forward(inputs)
                    if len(self.hparams.mixtype) > 0:
                        loss_func = self.mixup_criterion(
                            targets_a, targets_b, lam)
                        loss = loss_func(self.criterion, logits)
                    else:
                        loss = self.criterion(logits, target)
            else:
                images_all = torch.cat(inputs, 0)
                logits_all = self.forward(images_all)
                logits_clean, logits_aug1, logits_aug2 = torch.split(
                    logits_all, inputs[0].size(0))

                # Cross-entropy is only computed on clean images
                loss = F.cross_entropy(logits_clean, target)

                p_clean, p_aug1, p_aug2 = F.softmax(
                    logits_clean,
                    dim=1), F.softmax(logits_aug1,
                                      dim=1), F.softmax(logits_aug2, dim=1)

                # Clamp mixture distribution to avoid exploding KL divergence
                p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7,
                                        1).log()
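                # The averaged KL terms against the clamped mixture below form a
                # Jensen-Shannon consistency loss (as in AugMix), scaled by 12.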
                loss += 12 * (
                    F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

            assert not np.isnan(
                loss.item()), 'Model diverged with losses = NaN'

            loss.backward()
            self.optimizer.step()
            summary_train['step'] += 1
            losses.update(loss.item(), target.size(0))

            if summary_train['step'] % self.hparams.log_every == 0:
                time_spent = time.time() - time_now
                time_now = time.time()
                logging.info('Train, '
                             'Epoch : {}, '
                             'Step : {}/{}, '
                             'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                             'Run Time : {runtime:.2f} sec'.format(
                                 summary_train['epoch'] + 1,
                                 summary_train['step'],
                                 summary_train['total_step'],
                                 loss=losses,
                                 runtime=time_spent))
                print('Train, '
                      'Epoch : {}, '
                      'Step : {}/{}, '
                      'Loss: {loss.val:.4f} ({loss.avg:.4f}), '
                      'Run Time : {runtime:.2f} sec'.format(
                          summary_train['epoch'] + 1,
                          summary_train['step'],
                          summary_train['total_step'],
                          loss=losses,
                          runtime=time_spent))

            if summary_train['step'] % self.hparams.test_every == 0:
                self.validation_end(summary_dev, summary_train, best_dict)

            self.model.train()
            torch.set_grad_enabled(True)

        summary_train['epoch'] += 1
        return summary_train, best_dict

    def validation_step(self, summary_dev):
        """Summary
        Extract the batch of datapoints and return the predicted logits in validation step
        Args:
            summary_dev (TYPE): Description

        Returns:
            TYPE: summary_dev with the average loss, plus stacked predictions and targets
        """
        losses = AverageMeter()
        torch.set_grad_enabled(False)
        self.model.eval()

        output_ = np.array([])
        target_ = np.array([])

        with torch.no_grad():
            for i, (inputs, target, _) in enumerate(self.valid_loader):
                target = target.to(self.device)
                if isinstance(inputs, tuple):
                    inputs = tuple([
                        e.to(self.device) if type(e) == torch.Tensor else e
                        for e in inputs
                    ])
                else:
                    inputs = inputs.to(self.device)

                logits = self.forward(inputs)
                loss = self.criterion(logits, target)
                losses.update(loss.item(), target.size(0))

                if self.hparams.multi_cls:
                    output = F.softmax(logits, dim=1)
                    _, output = torch.max(output, 1)
                else:
                    output = torch.sigmoid(logits)

                target = target.detach().to('cpu').numpy()
                target_ = np.concatenate(
                    (target_, target), axis=0) if len(target_) > 0 else target
                y_pred = output.detach().to('cpu').numpy()
                output_ = np.concatenate(
                    (output_, y_pred), axis=0) if len(output_) > 0 else y_pred

        summary_dev['loss'] = losses.avg
        return summary_dev, output_, target_

    def validation_end(self, summary_dev, summary_train, best_dict):
        """Summary
        After the validation end, calculate the metrics
        Args:
            summary_dev (TYPE): Description
            summary_train (TYPE): Description
            best_dict (TYPE): Description

        Returns:
            TYPE: Description
        """
        time_now = time.time()
        summary_dev, output_, target_ = self.validation_step(summary_dev)
        time_spent = time.time() - time_now

        if not self.hparams.auto_threshold:
            overall_pre, overall_rec, overall_fscore = get_metrics(
                copy.deepcopy(output_), target_, self.cfg.beta,
                self.cfg.threshold, self.cfg.metric_type)
        else:
            overall_pre, overall_rec, overall_fscore = self.find_best_fixed_threshold(
                output_, target_)

        resp = dict()
        if not self.hparams.multi_cls:
            for t in range(len(self.num_tasks)):
                y_pred = np.transpose(output_)[t]
                precision, recall, f_score = get_metrics(
                    copy.deepcopy(y_pred),
                    np.transpose(target_)[t], self.cfg.beta,
                    self.cfg.threshold, 'binary')

                resp['precision_{}'.format(
                    self.labels[self.num_tasks[t]])] = precision
                resp['recall_{}'.format(
                    self.labels[self.num_tasks[t]])] = recall
                resp['f_score_{}'.format(
                    self.labels[self.num_tasks[t]])] = f_score

        resp['overall_precision'] = overall_pre
        resp['overall_recall'] = overall_rec
        resp['overall_f_score'] = overall_fscore

        logging.info(
            'Dev, Step : {}/{}, Loss : {}, Fscore : {:.3f}, Precision : {:.3f}, '
            'Recall : {:.3f}, Run Time : {:.2f} sec'.format(
                summary_train['step'], summary_train['total_step'],
                summary_dev['loss'], resp['overall_f_score'],
                resp['overall_precision'], resp['overall_recall'], time_spent))
        print(
            'Dev, Step : {}/{}, Loss : {}, Fscore : {:.3f}, Precision : {:.3f}, '
            'Recall : {:.3f}, Run Time : {:.2f} sec'.format(
                summary_train['step'], summary_train['total_step'],
                summary_dev['loss'], resp['overall_f_score'],
                resp['overall_precision'], resp['overall_recall'], time_spent))

        save_best = False
        mean_score = resp['overall_f_score']
        if mean_score > min(best_dict['score_top_k']):
            self.update_top_k(mean_score, best_dict, 'score')
            if self.hparams.metric == 'score':
                save_best = True

        mean_loss = summary_dev['loss']
        if mean_loss < max(best_dict['loss_top_k']):
            self.update_top_k(mean_loss, best_dict, 'loss')
            if self.hparams.metric == 'loss':
                save_best = True

        if save_best:
            torch.save(
                {
                    'epoch': summary_train['epoch'],
                    'step': summary_train['step'],
                    'score_dev_best': best_dict['score_dev_best'],
                    'loss_dev_best': best_dict['loss_dev_best'],
                    'state_dict': self.model.state_dict()
                },
                os.path.join(self.hparams.save_path,
                             'best{}.pth'.format(best_dict['score_curr_idx'])))

            logging.info(
                'Best {}, Step : {}/{}, Loss : {}, Score : {:.3f}'.format(
                    best_dict['score_curr_idx'], summary_train['step'],
                    summary_train['total_step'], summary_dev['loss'],
                    best_dict['score_dev_best']))

            print('Best {}, Step : {}/{}, Loss : {}, Score : {:.3f}'.format(
                best_dict['score_curr_idx'], summary_train['step'],
                summary_train['total_step'], summary_dev['loss'],
                best_dict['score_dev_best']))

    def find_best_fixed_threshold(self, output_, target_):
        score = list()
        thrs = np.arange(0, 1.0, 0.01)
        pre_rec = list()
        for thr in tqdm.tqdm(thrs):
            pre, rec, fscore = get_metrics(copy.deepcopy(output_),
                                           copy.deepcopy(target_),
                                           self.cfg.beta, thr,
                                           self.cfg.metric_type)
            score.append(fscore)
            pre_rec.append([pre, rec])

        score = np.array(score)
        pm = score.argmax()
        best_thr, best_score = thrs[pm], score[pm].item()
        best_pre, best_rec = pre_rec[pm]
        print('thr={} F2={} prec={} rec={}'.format(best_thr, best_score,
                                                   best_pre, best_rec))
        return best_pre, best_rec, best_score

    def test_step(self):
        """Summary
        Extract the batch of datapoints and return the predicted logits in test step
        Args:
            batch (TYPE): Description
            batch_nb (TYPE): Description

        Returns:
            TYPE: Description
        """
        torch.set_grad_enabled(False)
        self.model.eval()

        output_ = np.array([])
        target_ = np.array([])

        with torch.no_grad():
            for i, batch in enumerate(self.test_loader):
                if self.hparams.infer == 'valid':
                    # Evaluate
                    inputs, target, names = batch
                    target = target.to(self.device)
                else:
                    # Test
                    inputs, names = batch

                if isinstance(inputs, tuple):
                    inputs = tuple([
                        e.to(self.device) if type(e) == torch.Tensor else e
                        for e in inputs
                    ])
                else:
                    inputs = inputs.to(self.device)

                self.names.extend(names)

                if self.cfg.n_crops:
                    bs, n_crops, c, h, w = inputs.size()
                    inputs = inputs.view(-1, c, h, w)

                logits = self.forward(inputs)

                if self.cfg.n_crops:
                    logits = logits.view(bs, n_crops, -1).mean(1)

                if self.hparams.multi_cls:
                    output = F.softmax(logits, dim=1)
                    output = output[:, 1]
                else:
                    output = torch.sigmoid(logits)

                if self.hparams.infer == 'valid':
                    target = target.detach().to('cpu').numpy()
                    target_ = np.concatenate((target_, target), axis=0) if len(target_) > 0 else \
                        target
                y_pred = output.detach().to('cpu').numpy()
                output_ = np.concatenate(
                    (output_, y_pred), axis=0) if len(output_) > 0 else y_pred

        if self.hparams.infer == 'valid':
            return output_, target_
        else:
            return output_

    def test(self):
        """Summary
        After the test end, calculate the metrics
        Args:

        Returns:
            TYPE: Description
        """
        # inference dataset
        if self.hparams.infer == 'valid':
            output_, target_ = self.test_step()
        else:
            output_ = self.test_step()

        resp = dict()
        to_csv = {'Images': self.names}

        for t in range(len(self.num_tasks)):
            if self.hparams.multi_cls:
                y_pred = np.reshape(output_, output_.shape[0])
            else:
                y_pred = np.transpose(output_)[t]
            to_csv[self.labels[self.num_tasks[t]]] = y_pred

        # Only save scores to json file when in valid mode
        if self.hparams.infer == 'valid':
            overall_pre, overall_rec, overall_fscore = get_metrics(
                copy.deepcopy(output_), copy.deepcopy(target_), self.cfg.beta,
                self.cfg.threshold, self.cfg.metric_type)

            resp['overall_pre'] = overall_pre
            resp['overall_rec'] = overall_rec
            resp['overall_f_score'] = overall_fscore

            with open(
                    os.path.join(os.path.dirname(self.hparams.load),
                                 'scores_{}.json'.format(uuid.uuid4())),
                    'w') as f:
                json.dump(resp, f)

        # Save predictions to csv file for computing metrics in off-line mode
        path_df = DataFrame(to_csv, columns=to_csv.keys())
        path_df.to_csv(os.path.join(os.path.dirname(self.hparams.load),
                                    'predictions_{}.csv'.format(uuid.uuid4())),
                       index=False)
        return resp

    def configure_optimizers(self):
        """Summary
        Must be implemented
        Returns:
            TYPE: Description
        """
        optimizer = create_optimizer(self.cfg, self.model.parameters())

        if self.cfg.lr_scheduler == 'step':
            scheduler = optim.lr_scheduler.StepLR(optimizer,
                                                  step_size=self.cfg.step_size,
                                                  gamma=self.cfg.lr_factor)
        elif self.cfg.lr_scheduler == 'cosin':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                             T_max=200,
                                                             eta_min=1e-6)
        elif self.cfg.lr_scheduler == 'cosin_epoch':
            scheduler = optim.lr_scheduler.CosineAnnealingLR(
                optimizer, T_max=self.cfg.tmax, eta_min=self.cfg.eta_min)
        elif self.cfg.lr_scheduler == 'onecycle':
            max_lr = [g["lr"] for g in optimizer.param_groups]
            scheduler = optim.lr_scheduler.OneCycleLR(
                optimizer,
                max_lr=max_lr,
                epochs=self.hparams.epochs,
                steps_per_epoch=len(self.train_dataloader()))
            scheduler = {"scheduler": scheduler, "interval": "step"}
        else:
            raise ValueError(
                'Does not support {} learning rate scheduler'.format(
                    self.cfg.lr_scheduler))
        return optimizer, scheduler

    def train_dataloader(self):
        """Summary
        Return the train dataset, see dataflow/__init__.py
        Returns:
            TYPE: Description
        """
        ds_train = init_dataset(self.hparams.data_name,
                                cfg=self.cfg,
                                data_path=self.hparams.data_path,
                                mode='train')
        return DataLoader(dataset=ds_train,
                          batch_size=self.cfg.train_batch_size,
                          shuffle=True,
                          num_workers=self.hparams.num_workers,
                          pin_memory=True)

    def val_dataloader(self):
        """Summary
        Return the val dataset, see dataflow/__init__.py
        Returns:
            TYPE: Description
        """
        ds_val = init_dataset(self.hparams.data_name,
                              cfg=self.cfg,
                              data_path=self.hparams.data_path,
                              mode='valid')
        return DataLoader(dataset=ds_val,
                          batch_size=self.cfg.dev_batch_size,
                          shuffle=False,
                          num_workers=self.hparams.num_workers,
                          pin_memory=True)

    def test_dataloader(self):
        """Summary
        Return the test dataset, see dataflow/__init__.py
        Returns:
            TYPE: Description
        """
        ds_test = init_dataset(self.hparams.data_name,
                               cfg=self.cfg,
                               data_path=self.hparams.data_path,
                               mode=self.hparams.infer)
        return DataLoader(dataset=ds_test,
                          batch_size=self.cfg.dev_batch_size,
                          shuffle=False,
                          num_workers=self.hparams.num_workers,
                          pin_memory=True)

    def mix_data(self, x, y, device, alpha=1.0):
        """
        Re-constructed input images and labels based on one of two regularization methods such as Mixup and Cutmix.
        :param x: input images
        :param y: labels
        :param device: cpu or gpu device
        :param alpha: parameter for beta distribution
        :return: mixed inputs, pairs of targets, and lambda
        """
        if alpha > 0.:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1.
        batch_size = x.size()[0]
        index = torch.randperm(batch_size).to(device)
        y_a, y_b = y, y[index]

        if self.hparams.mixtype == 'mixup':
            mixed_x = lam * x + (1 - lam) * x[index, :]
        elif self.hparams.mixtype == 'cutmix':
            bbx1, bby1, bbx2, bby2 = self.rand_bbox(x.size(), lam)
            x[:, :, bbx1:bbx2, bby1:bby2] = x[index, :, bbx1:bbx2, bby1:bby2]
            mixed_x = x
            lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) /
                       (x.size()[-1] * x.size()[-2]))
        else:
            raise ValueError('Mixtype {} does not exist'.format(
                self.hparams.mixtype))

        return mixed_x, y_a, y_b, lam

    def create_scheduler(self, start_epoch):
        """
            Learning rate schedule with respect to epoch
            lr: float, initial learning rate
            lr_factor: float, decreasing factor every epoch_lr
            epoch_now: int, the current epoch
            lr_epochs: list of int, decreasing every epoch in lr_epochs
            return: lr, float, scheduled learning rate.
            """
        count = 0
        for epoch in self.hparams.lr_epochs.split(','):
            if start_epoch >= int(epoch):
                count += 1
                continue

            break

        return self.cfg.lr * np.power(self.cfg.lr_factor, count)

    @staticmethod
    def mixup_criterion(y_a, y_b, lam):
        """
        Reconstructed loss function for the Mixup/CutMix regularization technique
        Args:
            y_a: original labels
            y_b: shuffled labels after random permutation
            lam: generated point in beta distribution
        Returns:
            Combined loss function
        """
        return lambda criterion, pred: lam * criterion(pred, y_a) + (
            1 - lam) * criterion(pred, y_b)

    @staticmethod
    def rand_bbox(size, lam):
        """
        Generate random bounding box for specified cutting rate
        Args:
            size: image size including width and height
            lam: generated point in beta distribution
        Returns:
            Coordinates of top-left and right-bottom vertices of bounding box
        """
        W = size[2]
        H = size[3]
        cut_rat = np.sqrt(1. - lam)
        cut_w = int(W * cut_rat)
        cut_h = int(H * cut_rat)
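        # e.g. (hypothetical values) for a 224x224 input and lam=0.75,
        # cut_rat=0.5, so the cut box is roughly 112x112 before clipping.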

        # uniform
        cx = np.random.randint(W)
        cy = np.random.randint(H)

        bbx1 = np.clip(cx - cut_w // 2, 0, W)
        bby1 = np.clip(cy - cut_h // 2, 0, H)
        bbx2 = np.clip(cx + cut_w // 2, 0, W)
        bby2 = np.clip(cy + cut_h // 2, 0, H)

        return bbx1, bby1, bbx2, bby2

    def update_top_k(self, mean, best_dict, metric):
        metric_dev_best = '{}_dev_best'.format(metric)
        metric_top_k = '{}_top_k'.format(metric)
        metric_curr_idx = '{}_curr_idx'.format(metric)

        if metric == 'loss':
            if mean < best_dict[metric_dev_best]:
                best_dict[metric_dev_best] = mean
        else:
            if mean > best_dict[metric_dev_best]:
                best_dict[metric_dev_best] = mean

        if len(best_dict[metric_top_k]) >= self.hparams.save_top_k:
            if metric == 'loss':
                min_idx = best_dict[metric_top_k].index(
                    max(best_dict[metric_top_k]))
            else:
                min_idx = best_dict[metric_top_k].index(
                    min(best_dict[metric_top_k]))
            curr_idx = min_idx
            best_dict[metric_top_k][min_idx] = mean
        else:
            curr_idx = len(best_dict[metric_top_k])
            best_dict[metric_top_k].append(mean)

        best_dict[metric_curr_idx] = curr_idx
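
# Hypothetical usage of MLCModel (hparams assumed to come from the parser
# referenced in the class docstring):
#   model = MLCModel(hparams)
#   model.train()            # full training loop with periodic validation
#   results = model.test()   # inference, metrics in valid mode, CSV predictions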
Ejemplo n.º 11
0
def train_sentiment(opts):
    
    device = torch.device("cuda" if use_cuda else "cpu")

    glove_loader = GloveLoader(os.path.join(opts.data_dir, 'glove', opts.glove_emb_file))
    train_loader = DataLoader(RottenTomatoesReviewDataset(opts.data_dir, 'train', glove_loader, opts.maxlen), \
        batch_size=opts.bsize, shuffle=True, num_workers=opts.nworkers)
    valid_loader = DataLoader(RottenTomatoesReviewDataset(opts.data_dir, 'val', glove_loader, opts.maxlen), \
        batch_size=opts.bsize, shuffle=False, num_workers=opts.nworkers)
    model = Classifier(opts.hidden_size, opts.dropout_p, glove_loader, opts.enc_arch)

    if opts.optim == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), lr=opts.lr, weight_decay=opts.wd)
    else:
        raise NotImplementedError("Unknown optim type")

    criterion = nn.CrossEntropyLoss()

    start_n_iter = 0
    # for choosing the best model
    best_val_acc = 0.0

    model_path = os.path.join(opts.save_path, 'model_latest.net')
    if opts.resume and os.path.exists(model_path):
        # restoring training from save_state
        print ('====> Resuming training from previous checkpoint')
        save_state = torch.load(model_path, map_location='cpu')
        model.load_state_dict(save_state['state_dict'])
        start_n_iter = save_state['n_iter']
        best_val_acc = save_state['best_val_acc']
        opts = save_state['opts']
        opts.start_epoch = save_state['epoch'] + 1

    model = model.to(device)

    # for logging
    logger = TensorboardXLogger(opts.start_epoch, opts.log_iter, opts.log_dir)
    logger.set(['acc', 'loss'])
    logger.n_iter = start_n_iter

    for epoch in range(opts.start_epoch, opts.epochs):
        model.train()
        logger.step()

        for batch_idx, data in enumerate(train_loader):
            acc, loss = run_iter(opts, data, model, criterion, device)

            # optimizer step
            optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), opts.max_norm)
            optimizer.step()

            logger.update(acc, loss)

        val_loss, val_acc, time_taken = evaluate(opts, model, valid_loader, criterion, device)
        # log the validation losses
        logger.log_valid(time_taken, val_acc, val_loss)
        print ('')

        # Save the model to disk
        if val_acc >= best_val_acc:
            best_val_acc = val_acc
            save_state = {
                'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'n_iter': logger.n_iter,
                'opts': opts,
                'val_acc': val_acc,
                'best_val_acc': best_val_acc
            }
            model_path = os.path.join(opts.save_path, 'model_best.net')
            torch.save(save_state, model_path)

        save_state = {
            'epoch': epoch,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'n_iter': logger.n_iter,
            'opts': opts,
            'val_acc': val_acc,
            'best_val_acc': best_val_acc
        }
        model_path = os.path.join(opts.save_path, 'model_latest.net')
        torch.save(save_state, model_path)
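
# run_iter and evaluate are defined elsewhere in the original file; a minimal
# sketch of run_iter, assuming each batch is an (inputs, labels) pair
# (hypothetical, not the original implementation):
def run_iter(opts, data, model, criterion, device):
    # Forward pass plus batch accuracy; the loss tensor is returned so the
    # caller can run backward() and step the optimizer.
    x, y = data
    x, y = x.to(device), y.to(device)
    logits = model(x)
    loss = criterion(logits, y)
    acc = (logits.argmax(dim=1) == y).float().mean().item()
    return acc, loss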
Ejemplo n.º 12
0
        for index, (source, target) in enumerate(zip(source_loader, target_loader)):
            p = float(index + start_steps) / total_steps
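            # p tracks training progress; DANN-style training typically uses it to
            # anneal the gradient-reversal coefficient (e.g. 2/(1+exp(-10*p)) - 1),
            # but it is unused in this excerpt.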
            res = train(feature_extrator, class_classifier,domain_classifier, source,target, optimizer, index + start_steps)
            training_sta.append(res)

        test_source = test(feature_extrator,class_classifier, s_test_loader, epoch)
        test_target = test(feature_extrator, class_classifier, t_test_loader, epoch)

        test_s_sta.append(test_source)
        test_t_sta.append(test_target)
        print('###Test Source: Epoch: {}, avg_loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            epoch + 1,
            test_source['average_loss'],
            test_source['correct'],
            test_source['total'],
            test_source['accuracy'],
        ))
        print('###Test Target: Epoch: {}, avg_loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
            epoch + 1,
            test_target['average_loss'],
            test_target['correct'],
            test_target['total'],
            test_target['accuracy'],
        ))
    result_path = 'result_norm_dann'
    import os
    os.makedirs(result_path, exist_ok=True)
    torch.save([domain_classifier.state_dict(),feature_extrator.state_dict(),class_classifier.state_dict()], result_path + '/checkpoint.tar')
    save(training_sta, result_path + '/training_state.pkl')
    save(test_s_sta, result_path + '/test_s_sta.pkl')
    save(test_t_sta, result_path + '/test_t_sta.pkl')
    epoch = 0

    running_loss = 1.0

    # Simple training for 500 epochs

    for epoch in range(500):
        
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs = data[0].to(device)
            labels = data[1].to(device)

            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()
            if i % 200 == 199:
                print('[%d, %5d] loss: %.3f' %
                    (epoch + 1, i + 1, running_loss / 200))
                running_loss = 0.0

    print('Finished Training')
    # saves Classifier
    torch.save(net.state_dict(), PATH)
Ejemplo n.º 14
0
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--word-dim', type=int, default=300, help='size of word embeddings')
    parser.add_argument('--hidden-dim', type=int, default=300, help='number of hidden units per layer')
    parser.add_argument('--num-layers', type=int, default=1, help='number of layers in BiLSTM')
    parser.add_argument('--att-dim', type=int, default=350, help='number of attention unit')
    parser.add_argument('--att-hops', type=int, default=4, help='number of attention hops, for multi-hop attention model')
    parser.add_argument('--clf-hidden-dim', type=int, default=512, help='hidden (fully connected) layer size for classifier MLP')
    parser.add_argument('--clip', type=float, default=0.5, help='clip to prevent the too large grad in LSTM')
    parser.add_argument('--lr', type=float, default=.001, help='initial learning rate')
    parser.add_argument('--weight-decay', type=float, default=1e-5, help='weight decay rate per batch')
    parser.add_argument('--dropout', type=float, default=0.3)
    parser.add_argument('--max-epoch', type=int, default=8)
    parser.add_argument('--seed', type=int, default=666)
    parser.add_argument('--cuda', action='store_true', default=True)
    parser.add_argument('--optimizer', default='adam', choices=['adam', 'sgd'])
    parser.add_argument('--batch-size', type=int, default=32, help='batch size for training')
    parser.add_argument('--penalization-coeff', type=float, default=0.1, help='the penalization coefficient')
    parser.add_argument('--fix-word-embedding', action='store_true')


    parser.add_argument('--model-type', required=True, choices=['sa', 'avgblock', 'hard'])
    parser.add_argument('--data-type', required=True, choices=['age2', 'dbpedia', 'yahoo'])
    parser.add_argument('--data', required=True, help='pickle file obtained by dataset dump')
    parser.add_argument('--save-dir', type=str, required=True, help='path to save the final model')
    parser.add_argument('--block-size', type=int, default=-1, help='block size only when model-type is avgblock')
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    if torch.cuda.is_available():
        if not args.cuda:
            print("WARNING: You have a CUDA device, so you should probably run with --cuda")
        else:
            torch.cuda.manual_seed(args.seed)
    #######################################
    # a simple log file, the same content as stdout
    if not os.path.exists(args.save_dir):
        os.mkdir(args.save_dir)
    logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)-8s %(message)s')
    logFormatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
    rootLogger = logging.getLogger()
    fileHandler = logging.FileHandler(os.path.join(args.save_dir, 'stdout.log'))
    fileHandler.setFormatter(logFormatter)
    rootLogger.addHandler(fileHandler)
    ########################################
    for k, v in vars(args).items():
        logging.info(k+':'+str(v))

    #####################################################################
    if args.data_type == 'age2':
        data = AGE2(datapath=args.data, batch_size=args.batch_size)
        num_classes = 5
    elif args.data_type == 'dbpedia':
        data = DBpedia(datapath=args.data, batch_size=args.batch_size)
        num_classes = 14
    elif args.data_type == 'yahoo':
        data = Yahoo(datapath=args.data, batch_size=args.batch_size)
        num_classes = 10
    else:
        raise Exception('Invalid argument data-type')
    #####################################################################
    if args.model_type == 'avgblock':
        assert args.block_size > 0
    #####################################################################


    tic = time.time()
    model = Classifier(
        dictionary=data,
        dropout=args.dropout,
        num_words=data.num_words,
        num_layers=args.num_layers,
        hidden_dim=args.hidden_dim,
        word_dim=args.word_dim,
        att_dim=args.att_dim,
        att_hops=args.att_hops,
        clf_hidden_dim=args.clf_hidden_dim,
        num_classes=num_classes,
        model_type=args.model_type,
        block_size=args.block_size,
    )
    print('It takes %.2f sec to build the model.' % (time.time() - tic))
    logging.info(model)

    model.word_embedding.weight.data.set_(data.weight)
    if args.fix_word_embedding:
        model.word_embedding.weight.requires_grad = False
    if args.cuda:
        model = model.cuda()
    ''' count parameters
    num_params = sum(np.prod(p.size()) for p in model.parameters())
    num_embedding_params = np.prod(model.word_embedding.weight.size())
    print('# of parameters: %d' % num_params)
    print('# of word embedding parameters: %d' % num_embedding_params)
    print('# of parameters (excluding word embeddings): %d' % (num_params - num_embedding_params))
    '''
    if args.optimizer == 'adam':
        optimizer_class = optim.Adam
    elif args.optimizer == 'sgd':
        optimizer_class = optim.SGD
    else:
        raise Exception('For other optimizers, please add it yourself. supported ones are: SGD and Adam.')
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optimizer_class(params=params, lr=args.lr, weight_decay=args.weight_decay)
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer=optimizer, mode='max', factor=0.5, patience=10, verbose=True)
    criterion = nn.CrossEntropyLoss()
    # Identity matrix for each batch
    I = Variable(torch.eye(args.att_hops).unsqueeze(0).expand(args.batch_size, -1, -1))
    if args.cuda:
        I = I.cuda()
    trpack = {
            'model': model,
            'params': params, 
            'criterion': criterion, 
            'optimizer': optimizer,
            'I': I,
            }
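    # train_iter is not shown here; the identity matrix I above is typically used
    # for the self-attention penalization term of Lin et al. (2017), roughly
    #   penalty = ||A @ A.transpose(1, 2) - I||_F ** 2
    # added to the cross-entropy loss scaled by args.penalization_coeff.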

    train_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'train'), flush_secs=10)
    valid_summary_writer = tensorboard.FileWriter(
        logdir=os.path.join(args.save_dir, 'log', 'valid'), flush_secs=10)
    tsw, vsw = train_summary_writer, valid_summary_writer

    logging.info('number of train batches: %d' % data.train_num_batch)
    validate_every = data.train_num_batch // 10
    best_valid_accuracy = 0
    iter_count = 0
    tic = time.time()

    for epoch_num in range(args.max_epoch):
        for batch_iter, train_batch in enumerate(data.train_minibatch_generator()):
            progress = epoch_num + batch_iter / data.train_num_batch 
            iter_count += 1

            train_loss, train_accuracy = train_iter(args, train_batch, **trpack)
            add_scalar_summary(tsw, 'loss', train_loss, iter_count)
            add_scalar_summary(tsw, 'acc', train_accuracy, iter_count)

            if (batch_iter + 1) % (data.train_num_batch // 100) == 0:
                tac = (time.time() - tic) / 60
                print('   %.2f minutes\tprogress: %.2f' % (tac, progress))
            if (batch_iter + 1) % validate_every == 0:
                correct_sum = 0
                for valid_batch in data.dev_minibatch_generator():
                    correct, supplements = eval_iter(args, model, valid_batch)
                    correct_sum += unwrap_scalar_variable(correct)
                valid_accuracy = correct_sum / data.dev_size 
                scheduler.step(valid_accuracy)
                add_scalar_summary(vsw, 'acc', valid_accuracy, iter_count)
                logging.info('Epoch %.2f: valid accuracy = %.4f' % (progress, valid_accuracy))
                if valid_accuracy > best_valid_accuracy:
                    correct_sum = 0
                    for test_batch in data.test_minibatch_generator():
                        correct, supplements = eval_iter(args, model, test_batch)
                        correct_sum += unwrap_scalar_variable(correct)
                    test_accuracy = correct_sum / data.test_size
                    best_valid_accuracy = valid_accuracy
                    model_filename = ('model-%.2f-%.4f-%.4f.pkl' % (progress, valid_accuracy, test_accuracy))
                    model_path = os.path.join(args.save_dir, model_filename)
                    torch.save(model.state_dict(), model_path)
                    print('Saved the new best model to %s' % model_path)
Ejemplo n.º 15
0
def train(seed=0,
          dataset='grid',
          samplers=(UniformDatasetSampler, UniformLatentSampler),
          latent_dim=2,
          model_dim=256,
          device='cuda',
          conditional=False,
          learning_rate=2e-4,
          betas=(0.5, 0.9),
          batch_size=256,
          iterations=400,
          n_critic=5,
          objective='gan',
          gp_lambda=10,
          output_dir='results',
          plot=False,
          spec_norm=True):

    experiment_name = [
        seed, dataset, samplers[0].__name__, samplers[1].__name__, latent_dim,
        model_dim, device, conditional, learning_rate, betas[0], betas[1],
        batch_size, iterations, n_critic, objective, gp_lambda, plot, spec_norm
    ]
    experiment_name = '_'.join([str(p) for p in experiment_name])
    results_dir = os.path.join(output_dir, experiment_name)
    network_dir = os.path.join(results_dir, 'networks')
    eval_log = os.path.join(results_dir, 'eval.log')

    os.makedirs(results_dir, exist_ok=True)
    os.makedirs(network_dir, exist_ok=True)

    eval_file = open(eval_log, 'w')

    if plot:
        samples_dir = os.path.join(results_dir, 'samples')
        os.makedirs(samples_dir, exist_ok=True)

    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    data, labels = load_data(dataset)
    data_dim, num_classes = data.shape[1], len(set(labels))

    data_sampler = samplers[0](
        torch.tensor(data).float(),
        torch.tensor(labels).long()) if conditional else samplers[0](
            torch.tensor(data).float())
    noise_sampler = samplers[1](
        latent_dim, labels) if conditional else samplers[1](latent_dim)

    if conditional:
        test_data, test_labels = load_data(dataset, split='test')
        test_dataset = TensorDataset(
            torch.tensor(test_data).to(device).float(),
            torch.tensor(test_labels).to(device).long())
        test_dataloader = DataLoader(test_dataset, batch_size=4096)

        G = Generator(latent_dim + num_classes, model_dim,
                      data_dim).to(device).train()
        D = Discriminator(model_dim,
                          data_dim + num_classes,
                          spec_norm=spec_norm).to(device).train()

        C_real = Classifier(model_dim, data_dim,
                            num_classes).to(device).train()
        C_fake = Classifier(model_dim, data_dim,
                            num_classes).to(device).train()
        C_fake.load_state_dict(deepcopy(C_real.state_dict()))

        C_real_optimizer = optim.Adam(C_real.parameters(),
                                      lr=2 * learning_rate)
        C_fake_optimizer = optim.Adam(C_fake.parameters(),
                                      lr=2 * learning_rate)
        C_crit = nn.CrossEntropyLoss()
    else:
        G = Generator(latent_dim, model_dim, data_dim).to(device).train()
        D = Discriminator(model_dim, data_dim,
                          spec_norm=spec_norm).to(device).train()

    D_optimizer = optim.Adam(D.parameters(), lr=learning_rate, betas=betas)
    G_optimizer = optim.Adam(G.parameters(), lr=learning_rate, betas=betas)

    if objective == 'gan':
        fake_target = torch.zeros(batch_size, 1).to(device)
        real_target = torch.ones(batch_size, 1).to(device)
    elif objective == 'wgan':
        grad_target = torch.ones(batch_size, 1).to(device)
    elif objective == 'hinge':
        bound = torch.zeros(batch_size, 1).to(device)
        sub = torch.ones(batch_size, 1).to(device)

    stats = {'D': [], 'G': [], 'C_it': [], 'C_real': [], 'C_fake': []}
    if plot:
        fixed_latent_batch = noise_sampler.get_batch(20000)
        sample_figure = plt.figure(num=0, figsize=(5, 5))
        loss_figure = plt.figure(num=1, figsize=(10, 5))
        if conditional:
            accuracy_figure = plt.figure(num=2, figsize=(10, 5))

    for it in range(iterations + 1):
        # Train Discriminator
        data_batch = data_sampler.get_batch(batch_size)
        latent_batch = noise_sampler.get_batch(batch_size)

        if conditional:
            x_real, y_real = data_batch[0].to(device), data_batch[1].to(device)
            real_sample = torch.cat([x_real, y_real], dim=1)

            z_fake, y_fake = latent_batch[0].to(device), latent_batch[1].to(
                device)
            x_fake = G(torch.cat([z_fake, y_fake], dim=1)).detach()
            fake_sample = torch.cat([x_fake, y_fake], dim=1)

        else:
            x_real = data_batch.to(device)
            real_sample = x_real

            z_fake = latent_batch.to(device)
            x_fake = G(z_fake).detach()
            fake_sample = x_fake

        D.zero_grad()
        real_pred = D(real_sample)
        fake_pred = D(fake_sample)

        if is_recorded(data_sampler):
            data_sampler.record(real_pred.detach().cpu().numpy())

        if is_weighted(data_sampler):
            weights = torch.tensor(
                data_sampler.get_weights()).to(device).float().view(
                    real_pred.shape)
        else:
            weights = torch.ones_like(real_pred).to(device)

        if objective == 'gan':
            D_loss = F.binary_cross_entropy(fake_pred, fake_target) + (
                weights * F.binary_cross_entropy(
                    real_pred, real_target, reduction='none')).mean()
            stats['D'].append(D_loss.item())

        elif objective == 'wgan':
            alpha = torch.rand(batch_size,
                               1).expand(real_sample.size()).to(device)
            interpolate = (alpha * real_sample +
                           (1 - alpha) * fake_sample).requires_grad_(True)
            gradients = torch.autograd.grad(outputs=D(interpolate),
                                            inputs=interpolate,
                                            grad_outputs=grad_target,
                                            create_graph=True,
                                            retain_graph=True,
                                            only_inputs=True)[0]

            gradient_penalty = (gradients.norm(2, dim=1) -
                                1).pow(2).mean() * gp_lambda

            D_loss = fake_pred.mean() - (real_pred * weights).mean()
            stats['D'].append(-D_loss.item())
            D_loss += gradient_penalty

        elif objective == 'hinge':
            D_loss = -(torch.min(real_pred - sub, bound) *
                       weights).mean() - torch.min(-fake_pred - sub,
                                                   bound).mean()
            stats['D'].append(D_loss.item())

        D_loss.backward()
        D_optimizer.step()

        # Train Generator
        if it % n_critic == 0:
            G.zero_grad()

            latent_batch = noise_sampler.get_batch(batch_size)

            if conditional:
                z_fake, y_fake = latent_batch[0].to(
                    device), latent_batch[1].to(device)
                x_fake = G(torch.cat([z_fake, y_fake], dim=1))
                fake_pred = D(torch.cat([x_fake, y_fake], dim=1))
            else:
                z_fake = latent_batch.to(device)
                x_fake = G(z_fake)
                fake_pred = D(x_fake)

            if objective == 'gan':
                G_loss = F.binary_cross_entropy(fake_pred, real_target).mean()
                stats['G'].extend([G_loss.item()] * n_critic)
            elif objective == 'wgan':
                G_loss = -fake_pred.mean()
                stats['G'].extend([-G_loss.item()] * n_critic)
            elif objective == 'hinge':
                G_loss = -fake_pred.mean()
                stats['G'].extend([-G_loss.item()] * n_critic)

            G_loss.backward()
            G_optimizer.step()

        if conditional:
            # Train fake classifier
            C_fake.train()

            C_fake.zero_grad()
            C_fake_loss = C_crit(C_fake(x_fake.detach()), y_fake.argmax(1))
            C_fake_loss.backward()
            C_fake_optimizer.step()

            # Train real classifier
            C_real.train()

            C_real.zero_grad()
            C_real_loss = C_crit(C_real(x_real), y_real.argmax(1))
            C_real_loss.backward()
            C_real_optimizer.step()

        if it % 5 == 0:
            # Classifier evaluation only applies in the conditional setting.
            if conditional:
                C_real.eval()
                C_fake.eval()
                real_correct, fake_correct, total = 0.0, 0.0, 0.0
                for idx, (sample, label) in enumerate(test_dataloader):
                    real_correct += (
                        C_real(sample).argmax(1).view(-1) == label).sum()
                    fake_correct += (
                        C_fake(sample).argmax(1).view(-1) == label).sum()
                    total += sample.shape[0]

                stats['C_it'].append(it)
                stats['C_real'].append(real_correct.item() / total)
                stats['C_fake'].append(fake_correct.item() / total)

            line = f"{it}\t{stats['D'][-1]:.3f}\t{stats['G'][-1]:.3f}"
            if conditional:
                line += f"\t{stats['C_real'][-1]*100:.3f}\t{stats['C_fake'][-1]*100:.3f}"

            print(line)
            print(line, file=eval_file)

            if plot:
                if conditional:
                    z_fake, y_fake = fixed_latent_batch[0].to(
                        device), fixed_latent_batch[1].to(device)
                    x_fake = G(torch.cat([z_fake, y_fake], dim=1))
                else:
                    z_fake = fixed_latent_batch.to(device)
                    x_fake = G(z_fake)

                generated = x_fake.detach().cpu().numpy()

                plt.figure(0)
                plt.clf()
                plt.scatter(generated[:, 0],
                            generated[:, 1],
                            marker='.',
                            color=(0, 1, 0, 0.01))
                plt.axis('equal')
                plt.xlim(-1, 1)
                plt.ylim(-1, 1)
                plt.savefig(os.path.join(samples_dir, f'{it}.png'))

                plt.figure(1)
                plt.clf()
                plt.plot(stats['G'], label='Generator')
                plt.plot(stats['D'], label='Discriminator')
                plt.legend()
                plt.savefig(os.path.join(results_dir, 'loss.png'))

                if conditional:
                    plt.figure(2)
                    plt.clf()
                    plt.plot(stats['C_it'], stats['C_real'], label='Real')
                    plt.plot(stats['C_it'], stats['C_fake'], label='Fake')
                    plt.legend()
                    plt.savefig(os.path.join(results_dir, 'accuracy.png'))

    save_model(G, os.path.join(network_dir, 'G_trained.pth'))
    save_model(D, os.path.join(network_dir, 'D_trained.pth'))
    save_stats(stats, os.path.join(results_dir, 'stats.pth'))
    if conditional:
        save_model(C_real, os.path.join(network_dir, 'C_real_trained.pth'))
        save_model(C_fake, os.path.join(network_dir, 'C_fake_trained.pth'))
    eval_file.close()