Example #1
import os

import numpy as np
import torch
from tqdm import tqdm

# Also assumed from the surrounding project (definitions not shown here):
# GradualWarmupScheduler, LabelSmoothingLoss, and the test() evaluation helper.

class TrainLoop(object):
    def __init__(self,
                 models_dict,
                 optimizer_task,
                 source_loader,
                 test_source_loader,
                 target_loader,
                 nadir_slack,
                 alpha,
                 patience,
                 factor,
                 label_smoothing,
                 warmup_its,
                 lr_threshold,
                 verbose=-1,
                 cp_name=None,
                 save_cp=False,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 cuda=True,
                 logging=False,
                 ablation='no',
                 train_mode='hv'):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                os.mkdir(self.checkpoint_path)

        if cp_name:
            self.save_epoch_fmt_task = os.path.join(self.checkpoint_path,
                                                    'task' + cp_name)
            self.save_epoch_fmt_domain = os.path.join(self.checkpoint_path,
                                                      'Domain_{}' + cp_name)
        else:
            self.save_epoch_fmt_task = os.path.join(self.checkpoint_path,
                                                    'task_checkpoint_{}ep.pt')
            self.save_epoch_fmt_domain = os.path.join(self.checkpoint_path,
                                                      'Domain_{}.pt')

        self.cuda_mode = cuda
        self.feature_extractor = models_dict['feature_extractor']
        self.task_classifier = models_dict['task_classifier']
        self.domain_discriminator_list = models_dict[
            'domain_discriminator_list']
        self.optimizer_task = optimizer_task
        self.source_loader = source_loader
        self.test_source_loader = test_source_loader
        self.target_loader = target_loader
        self.history = {
            'loss_task': [],
            'hypervolume': [],
            'loss_domain': [],
            'accuracy_source': [],
            'accuracy_target': []
        }
        self.cur_epoch = 0
        self.total_iter = 0
        self.nadir_slack = nadir_slack
        self.alpha = alpha
        self.ablation = ablation
        self.train_mode = train_mode
        self.device = next(self.feature_extractor.parameters()).device

        # Iterations per epoch (round up when the last batch is partial).
        n_batches, remainder = divmod(len(source_loader.dataset),
                                      source_loader.batch_size)
        its_per_epoch = n_batches + 1 if remainder > 0 else n_batches
        patience = patience * (1 + its_per_epoch)
        self.after_scheduler_task = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_task,
            factor=factor,
            patience=patience,
            verbose=verbose > 0,
            threshold=lr_threshold,
            min_lr=1e-7)
        self.after_scheduler_disc_list = [
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                disc.optimizer,
                factor=factor,
                patience=patience,
                verbose=verbose > 0,
                threshold=lr_threshold,
                min_lr=1e-7) for disc in self.domain_discriminator_list
        ]
        self.verbose = verbose
        self.save_cp = save_cp

        self.scheduler_task = GradualWarmupScheduler(
            self.optimizer_task,
            total_epoch=warmup_its,
            after_scheduler=self.after_scheduler_task)
        self.scheduler_disc_list = [
            GradualWarmupScheduler(self.domain_discriminator_list[i].optimizer,
                                   total_epoch=warmup_its,
                                   after_scheduler=sch_disc)
            for i, sch_disc in enumerate(self.after_scheduler_disc_list)
        ]

        if checkpoint_epoch is not None:
            self.load_checkpoint(checkpoint_epoch)

        self.logging = logging
        if self.logging:
            from torch.utils.tensorboard import SummaryWriter
            self.writer = SummaryWriter()

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(label_smoothing,
                                                   lbl_set_size=7)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()

        # Weighted NLL loss for the domain discriminators (NLLLoss expects
        # log-probabilities, so each discriminator is assumed to end in log-softmax).
        weight = torch.tensor([2.0 / 3.0, 1.0 / 3.0]).to(self.device)
        self.d_cr = torch.nn.NLLLoss(weight=weight)

    def adjust_learning_rate(self,
                             optimizer,
                             epoch=1,
                             every_n=700,
                             In_lr=0.01):
        """Sets the learning rate to the initial LR decayed by 10 every `every_n` epochs."""
        lr = In_lr * (0.1**(epoch // every_n))
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        return lr

    def train(self, n_epochs=1, save_every=1):
        # An earlier (disabled) variant annealed alpha and both learning rates with
        # the DANN-style schedule lambd = 2 / (1 + exp(-10 * p)) - 1, p = step / num_steps.
        max_lr_t = list(self.optimizer_task.param_groups)[-1]['initial_lr']
        max_lr_d = list(self.domain_discriminator_list[0].optimizer.
                        param_groups)[-1]['initial_lr']

        while self.cur_epoch < n_epochs:

            cur_loss_task = 0
            cur_hypervolume = 0
            cur_loss_total = 0

            source_iter = tqdm(enumerate(self.source_loader), disable=False)
            # With the annealing disabled, these hold the initial learning rates and
            # feed only the progress printout below; the warmup/plateau schedulers
            # manage the actual rates.
            lr_t = max_lr_t
            lr_d = max_lr_d

            print(
                'Epoch {}/{} | Alpha = {:1.3} | Lr_task = {:1.4} | Lr_dis = {:1.4} '
                .format(self.cur_epoch + 1, n_epochs, self.alpha, lr_t, lr_d))

            for t, batch in source_iter:
                if self.ablation == 'all':
                    cur_losses = self.train_step_ablation_all(batch)
                else:
                    cur_losses = self.train_step(batch)

                self.scheduler_task.step(epoch=self.total_iter,
                                         metrics=1. -
                                         self.history['accuracy_source'][-1]
                                         if self.cur_epoch > 0 else np.inf)
                for sched in self.scheduler_disc_list:
                    sched.step(epoch=self.total_iter,
                               metrics=1. - self.history['accuracy_source'][-1]
                               if self.cur_epoch > 0 else np.inf)

                cur_loss_task += cur_losses[0]
                cur_hypervolume += cur_losses[1]
                cur_loss_total += cur_losses[2]
                self.total_iter += 1

                if self.logging:
                    self.writer.add_scalar('train/task_loss', cur_losses[0],
                                           self.total_iter)
                    self.writer.add_scalar('train/hypervolume_loss',
                                           cur_losses[1], self.total_iter)
                    self.writer.add_scalar('train/total_loss', cur_losses[2],
                                           self.total_iter)

            self.history['loss_task'].append(cur_loss_task / (t + 1))
            self.history['hypervolume'].append(cur_hypervolume / (t + 1))

            print('Current task loss: {}.'.format(cur_loss_task / (t + 1)))
            print('Current hypervolume: {}.'.format(cur_hypervolume / (t + 1)))

            self.history['accuracy_source'].append(
                test(self.test_source_loader,
                     self.feature_extractor,
                     self.task_classifier,
                     self.domain_discriminator_list,
                     self.device,
                     source_target='source',
                     epoch=self.cur_epoch,
                     tb_writer=self.writer if self.logging else None))
            self.history['accuracy_target'].append(
                test(self.target_loader,
                     self.feature_extractor,
                     self.task_classifier,
                     self.domain_discriminator_list,
                     self.device,
                     source_target='target',
                     epoch=self.cur_epoch,
                     tb_writer=self.writer if self.logging else None))

            idx = np.argmax(self.history['accuracy_source'])

            print(
                'Valid. on SOURCE data - Current acc., best acc., best acc target, and epoch: {:0.4f}, {:0.4f}, {:0.4f}, {}'
                .format(self.history['accuracy_source'][-1],
                        np.max(self.history['accuracy_source']),
                        self.history['accuracy_target'][idx], 1 + idx))
            print(
                'Valid. on TARGET data - Current acc., best acc., and epoch: {:0.4f}, {:0.4f}, {}'
                .format(self.history['accuracy_target'][-1],
                        np.max(self.history['accuracy_target']),
                        1 + np.argmax(self.history['accuracy_target'])))

            if self.logging:
                self.writer.add_scalar(
                    'misc/LR-task', self.optimizer_task.param_groups[0]['lr'],
                    self.total_iter)
                for i, disc in enumerate(self.domain_discriminator_list):
                    self.writer.add_scalar(
                        'misc/LR-disc{}'.format(i),
                        disc.optimizer.param_groups[0]['lr'], self.total_iter)

            self.cur_epoch += 1

            if self.save_cp and (
                    self.cur_epoch % save_every == 0
                    or self.history['accuracy_target'][-1] >
                    np.max([-np.inf] + self.history['accuracy_target'][:-1])):
                self.checkpointing()

        if self.logging:
            self.writer.close()

        idx_final = np.argmax(self.history['accuracy_source'])
        idx_loss = np.argmin(self.history['loss_task'])
        print('Epoch with min task loss = {} and corresponding target accuracy = {}'.
              format(idx_loss + 1, self.history['accuracy_target'][idx_loss]))

        return np.max(
            self.history['accuracy_target']
        ), self.history['accuracy_target'][idx_final], self.history[
            'accuracy_target'][idx_loss], self.history['accuracy_target'][-1]

    def train_step(self, batch):
        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, y_domain_1, y_domain_2, y_domain_3 = batch

        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)
        y_domain = torch.cat((y_domain_1, y_domain_2, y_domain_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
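        # Detach features so the discriminator updates below do not backpropagate
        # into the feature extractor.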
        features_ = features.detach()

        # DOMAIN DISCRIMINATORS (first pass: update each discriminator on
        # detached features)
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features_).squeeze()

            # Binary target: 1 where the sample comes from domain i, 0 elsewhere.
            curr_y_domain = torch.where(y_domain == i,
                                        torch.ones(y_domain.size(0)),
                                        torch.zeros(y_domain.size(0))).long()
            if self.cuda_mode:
                curr_y_domain = curr_y_domain.to(self.device)

            loss_domain_discriminator = self.d_cr(y_predict, curr_y_domain)

            if self.logging:
                self.writer.add_scalar('train/D{}_loss'.format(i),
                                       loss_domain_discriminator,
                                       self.total_iter)

            disc.optimizer.zero_grad()
            loss_domain_discriminator.backward()
            disc.optimizer.step()

        # UPDATE TASK CLASSIFIER AND FEATURE EXTRACTOR
        task_out = self.task_classifier.forward(features)

        loss_domain_disc_list = []
        loss_domain_disc_list_float = []
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features).squeeze()
            # Flipped target: 0 where the sample comes from domain i, 1 elsewhere,
            # so the feature extractor is trained to fool each discriminator.
            curr_y_domain = torch.where(y_domain == i,
                                        torch.zeros(y_domain.size(0)),
                                        torch.ones(y_domain.size(0))).long()
            if self.cuda_mode:
                curr_y_domain = curr_y_domain.to(self.device)

            loss_domain_disc_list.append(self.d_cr(y_predict, curr_y_domain))

            loss_domain_disc_list_float.append(
                loss_domain_disc_list[-1].detach().item())

        if self.train_mode == 'hv':
            self.update_nadir_point(loss_domain_disc_list_float)

        hypervolume = 0
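        # 'hv' mode: minimizing -log(nadir - loss) drives every flipped-label
        # discriminator loss down jointly, weighting the objectives nearest the
        # nadir point most heavily; 'avg' mode simply sums them.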
        for loss in loss_domain_disc_list:
            if self.train_mode == 'hv':
                hypervolume -= torch.log(self.nadir - loss + 1e-6)

            elif self.train_mode == 'avg':
                hypervolume -= loss

        task_loss = self.ce_criterion(task_out, y_task)
        loss_total = self.alpha * task_loss + (
            1 - self.alpha) * hypervolume / len(loss_domain_disc_list)

        self.optimizer_task.zero_grad()
        loss_total.backward()
        self.optimizer_task.step()

        losses_return = task_loss.item(), hypervolume.item(), loss_total.item()
        return losses_return

    def train_step_ablation_all(self, batch):

        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, _, _, _ = batch

        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
        task_out = self.task_classifier.forward(features)
        task_loss = torch.nn.CrossEntropyLoss()(task_out, y_task)

        self.optimizer_task.zero_grad()
        task_loss.backward()
        self.optimizer_task.step()

        # Return a 3-tuple to match train_step: (task loss, hypervolume, total loss);
        # with the discriminators ablated, the total loss is the task loss itself.
        losses_return = task_loss.item(), 0., task_loss.item()

        return losses_return

    def checkpointing(self):
        if self.verbose > 0:
            print(' ')
            print('Checkpointing...')

        ckpt = {
            'feature_extractor_state': self.feature_extractor.state_dict(),
            'task_classifier_state': self.task_classifier.state_dict(),
            'optimizer_task_state': self.optimizer_task.state_dict(),
            'scheduler_task_state': self.scheduler_task.state_dict(),
            'history': self.history,
            'cur_epoch': self.cur_epoch
        }
        torch.save(ckpt, self.save_epoch_fmt_task.format(self.cur_epoch))

        for i, disc in enumerate(self.domain_discriminator_list):
            ckpt = {
                'model_state': disc.state_dict(),
                'optimizer_disc_state': disc.optimizer.state_dict(),
                'scheduler_disc_state':
                self.scheduler_disc_list[i].state_dict()
            }
            torch.save(ckpt, self.save_epoch_fmt_domain.format(i + 1))

    def load_checkpoint(self, epoch):
        ckpt = self.save_epoch_fmt_task.format(epoch)

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt)
            # Load model states (the domain discriminators are restored below)
            self.feature_extractor.load_state_dict(
                ckpt['feature_extractor_state'])
            self.task_classifier.load_state_dict(ckpt['task_classifier_state'])
            # Load optimizer state
            self.optimizer_task.load_state_dict(ckpt['optimizer_task_state'])
            # Load scheduler state
            self.scheduler_task.load_state_dict(ckpt['scheduler_task_state'])
            # Load history
            self.history = ckpt['history']
            self.cur_epoch = ckpt['cur_epoch']

            for i, disc in enumerate(self.domain_discriminator_list):
                ckpt = torch.load(self.save_epoch_fmt_domain.format(i + 1))
                disc.load_state_dict(ckpt['model_state'])
                disc.optimizer.load_state_dict(ckpt['optimizer_disc_state'])
                self.scheduler_disc_list[i].load_state_dict(
                    ckpt['scheduler_disc_state'])

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self, model):
        norm = 0.0
        for params in list(
                filter(lambda p: p.grad is not None, model.parameters())):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))

    def update_nadir_point(self, losses_list):
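        # Set the nadir slightly above the worst current loss (nadir_slack is
        # expected to be > 1) so that log(nadir - loss) in train_step stays finite.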
        self.nadir = float(np.max(losses_list) * self.nadir_slack + 1e-8)
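
A minimal usage sketch for the TrainLoop above, assuming hypothetical FeatureExtractor, TaskClassifier, and DomainDiscriminator modules plus pre-built source/target loaders; every name and hyperparameter value below is illustrative, not the original project's. Note that TrainLoop expects each discriminator to carry its own optimizer.

# Sketch only: the model classes and loaders are stand-ins for the project's own.
feature_extractor = FeatureExtractor()
task_classifier = TaskClassifier()
domain_discriminator_list = [DomainDiscriminator() for _ in range(3)]
for disc in domain_discriminator_list:
    # Per-discriminator optimizer, accessed as disc.optimizer inside TrainLoop.
    disc.optimizer = torch.optim.SGD(disc.parameters(), lr=1e-2, momentum=0.9)

models_dict = {
    'feature_extractor': feature_extractor,
    'task_classifier': task_classifier,
    'domain_discriminator_list': domain_discriminator_list
}
optimizer_task = torch.optim.SGD(
    list(feature_extractor.parameters()) + list(task_classifier.parameters()),
    lr=1e-2, momentum=0.9)

trainer = TrainLoop(models_dict, optimizer_task,
                    source_loader, test_source_loader, target_loader,
                    nadir_slack=1.5, alpha=0.8, patience=20, factor=0.5,
                    label_smoothing=0.0, warmup_its=500, lr_threshold=1e-4,
                    save_cp=False, cuda=torch.cuda.is_available())
best_acc, acc_at_best_source, acc_at_min_loss, last_acc = trainer.train(
    n_epochs=100, save_every=5)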
Example #2
import time
from pathlib import Path

import dgl
import numpy as np
import psutil
import torch
from dgl.data.utils import load_graphs

# Also assumed from the surrounding project (definitions not shown here): Encoder,
# Decoder, Msg2Mail, frauder_sampler, dataloader, get_current_ts, eval_epoch,
# EarlyStopMonitor, and GradualWarmupScheduler.

def train(args, logger):
    task_time = time.strftime("%Y-%m-%d %H:%M", time.localtime())
    Path("./saved_models/").mkdir(parents=True, exist_ok=True)
    Path("./pretrained_models/").mkdir(parents=True, exist_ok=True)
    MODEL_SAVE_PATH = './saved_models/'
    Pretrained_MODEL_PATH = './pretrained_models/'
    get_model_name = lambda part: f'{part}-{args.data}-{args.tasks}-{args.prefix}.pth'
    get_pretrain_model_name = lambda part: f'{part}-{args.data}-LP-{args.prefix}.pth'
    device_string = 'cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu >= 0 else 'cpu'
    print('Model training with ' + device_string)
    device = torch.device(device_string)
    


    g = load_graphs(f"./data/{args.data}.dgl")[0][0]
    
    efeat_dim = g.edata['feat'].shape[1]
    nfeat_dim = efeat_dim


    train_loader, val_loader, test_loader, num_val_samples, num_test_samples = dataloader(args, g)


    encoder = Encoder(args, nfeat_dim, n_head=args.n_head, dropout=args.dropout).to(device)
    decoder = Decoder(args, nfeat_dim).to(device)
    msg2mail = Msg2Mail(args, nfeat_dim)
    fraud_sampler = frauder_sampler(g)

    optimizer = torch.optim.Adam(list(encoder.parameters()) + list(decoder.parameters()), lr=args.lr, weight_decay=args.weight_decay)
    scheduler_lr = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=40)
    if args.warmup:
        scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=3, after_scheduler=scheduler_lr)
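        # Dummy optimizer step so the scheduler's first step() call does not trigger
        # PyTorch's "scheduler stepped before optimizer" warning.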
        optimizer.zero_grad()
        optimizer.step()
    loss_fcn = torch.nn.BCEWithLogitsLoss()

    loss_fcn = loss_fcn.to(device)

    early_stopper = EarlyStopMonitor(logger=logger, max_round=args.patience, higher_better=True)

    if args.pretrain:
        logger.info(f'Loading the linkpred pretrained attention based encoder model')
        encoder.load_state_dict(torch.load(Pretrained_MODEL_PATH+get_pretrain_model_name('Encoder')))

    for epoch in range(args.n_epoch):
        # reset node state
        g.ndata['mail'] = torch.zeros((g.num_nodes(), args.n_mail, nfeat_dim+2), dtype=torch.float32) 
        g.ndata['feat'] = torch.zeros((g.num_nodes(), nfeat_dim), dtype=torch.float32)  # initialized to zero; other initializations are possible
        g.ndata['last_update'] = torch.zeros((g.num_nodes()), dtype=torch.float32) 
        encoder.train()
        decoder.train()
        start_epoch = time.time()
        m_loss = []
        logger.info('start {} epoch, current optim lr is {}'.format(epoch, optimizer.param_groups[0]['lr']))
        for batch_idx, (input_nodes, pos_graph, neg_graph, blocks, frontier, current_ts) in enumerate(train_loader):
            

            pos_graph = pos_graph.to(device)
            neg_graph = neg_graph.to(device) if neg_graph is not None else None
            

            if not args.no_time or not args.no_pos:
                current_ts, pos_ts, num_pos_nodes = get_current_ts(args, pos_graph, neg_graph)
                pos_graph.ndata['ts'] = current_ts
            else:
                current_ts, pos_ts, num_pos_nodes = None, None, None
            
            neg_graph_rev = dgl.add_reverse_edges(neg_graph) if neg_graph is not None else None
            emb, _ = encoder(dgl.add_reverse_edges(pos_graph), neg_graph_rev, num_pos_nodes)
            if batch_idx != 0:
                if 'LP' not in args.tasks and args.balance:
                    neg_graph = fraud_sampler.sample_fraud_event(g, args.bs//5, current_ts.max().cpu()).to(device)
                logits, labels = decoder(emb, pos_graph, neg_graph)

                loss = loss_fcn(logits, labels)
                
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                m_loss.append(loss.item())


            # MSG Passing
            with torch.no_grad():
                mail = msg2mail.gen_mail(args, emb, input_nodes, pos_graph, frontier, 'train')

                if not args.no_time:
                    g.ndata['last_update'][pos_graph.ndata[dgl.NID][:num_pos_nodes]] = pos_ts.to('cpu')
                g.ndata['feat'][pos_graph.ndata[dgl.NID]] = emb.to('cpu')
                g.ndata['mail'][input_nodes] = mail
            if batch_idx % 100 == 1:
                gpu_mem = torch.cuda.max_memory_allocated() / 1.074e9 if torch.cuda.is_available() and args.gpu >= 0 else 0
                torch.cuda.empty_cache()
                mem_perc = psutil.virtual_memory().percent
                cpu_perc = psutil.cpu_percent(interval=None)
                output_string = f'Epoch {epoch} | Step {batch_idx}/{len(train_loader)} | CPU {cpu_perc:.1f}% | Sys Mem {mem_perc:.1f}% | GPU Mem {gpu_mem:.4f}GB '
                
                output_string += f'| {args.tasks} Loss {np.mean(m_loss):.4f}'

                logger.info(output_string)

        total_epoch_time = time.time() - start_epoch
        logger.info(' training epoch: {} took {:.4f}s'.format(epoch, total_epoch_time))
        val_ap, val_auc, val_acc, val_loss = eval_epoch(args, logger, g, val_loader, encoder, decoder, msg2mail, loss_fcn, device, num_val_samples)
        logger.info('Val {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(args.tasks, val_ap, val_auc, val_acc, val_loss))

        if args.warmup:
            scheduler_warmup.step(epoch)
        else:
            scheduler_lr.step()

        early_stopper_metric = val_ap if 'LP' in args.tasks else val_auc

        if early_stopper.early_stop_check(early_stopper_metric):
            logger.info('No improvement over {} epochs, stop training'.format(early_stopper.max_round))
            logger.info(f'Loading the best model at epoch {early_stopper.best_epoch}')
            encoder.load_state_dict(torch.load(MODEL_SAVE_PATH+get_model_name('Encoder')))
            decoder.load_state_dict(torch.load(MODEL_SAVE_PATH+get_model_name('Decoder')))

            test_result = [early_stopper.best_ap, early_stopper.best_auc, early_stopper.best_acc, early_stopper.best_loss]
            break

        test_ap, test_auc, test_acc, test_loss = eval_epoch(args, logger, g, test_loader, encoder, decoder, msg2mail, loss_fcn, device, num_test_samples)
        logger.info('Test {} Task | ap: {:.4f} | auc: {:.4f} | acc: {:.4f} | Loss: {:.4f}'.format(args.tasks, test_ap, test_auc, test_acc, test_loss))
        test_result = [test_ap, test_auc, test_acc, test_loss]

        if early_stopper.best_epoch == epoch: 
            early_stopper.best_ap = test_ap
            early_stopper.best_auc = test_auc
            early_stopper.best_acc = test_acc
            early_stopper.best_loss = test_loss
            logger.info(f'Saving the best model at epoch {early_stopper.best_epoch}')
            torch.save(encoder.state_dict(), MODEL_SAVE_PATH+get_model_name('Encoder'))
            torch.save(decoder.state_dict(), MODEL_SAVE_PATH+get_model_name('Decoder'))
Example #3
    params_to_finetune = list(im_im_scorer_model.parameters())
    models_to_save.append(im_im_scorer_model)

    # optimizer
    optfunc = {
        'adam': optim.Adam,
        'rmsprop': optim.RMSprop,
        'sgd': optim.SGD
    }[args.optimizer]
    pretrain_optimizer = optfunc(params_to_pretrain, lr=args.pt_lr)
    finetune_optimizer = optfunc(params_to_finetune, lr=args.ft_lr)
    # models_to_save.append(optimizer)
    after_scheduler = optim.lr_scheduler.StepLR(pretrain_optimizer, 4000, 0.5)
    # after_scheduler = optim.lr_scheduler.CosineAnnealingLR(pretrain_optimizer, T_max=500*args.pt_epochs-1000)
    pt_scheduler = GradualWarmupScheduler(pretrain_optimizer,
                                          1.0,
                                          total_epoch=1000,
                                          after_scheduler=after_scheduler)
    ft_scheduler = optim.lr_scheduler.StepLR(finetune_optimizer,
                                             80 * args.ft_epochs, 0.5)
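    # Report parameter counts for the pretrain and finetune groups.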
    print(sum([p.numel() for p in params_to_pretrain]))
    print(sum([p.numel() for p in params_to_finetune]))

    if args.load_checkpoint and os.path.exists(
            os.path.join(args.exp_dir, 'checkpoint.pth.tar')):
        ckpt_path = os.path.join(args.exp_dir, 'checkpoint.pth.tar')
        sds = torch.load(ckpt_path, map_location=device)
        for m in models_to_save:
            if (not isinstance(m, TransformerAgg)) and (not isinstance(
                    m, RelationNetAgg)):
                print(m.load_state_dict(sds[repr(m)]))
        print("loaded checkpoint")
Example #4
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

# Also assumed from the surrounding project (definitions not shown here):
# GradualWarmupScheduler, LabelSmoothingLoss, and the test() evaluation helper.

class TrainLoop(object):
    def __init__(self,
                 args,
                 models_dict,
                 optimizer_task,
                 source_loader,
                 test_source_loader,
                 target_loader,
                 nadir_slack,
                 alpha,
                 patience,
                 factor,
                 label_smoothing,
                 warmup_its,
                 lr_threshold,
                 verbose=-1,
                 cp_name=None,
                 save_cp=True,
                 checkpoint_path=None,
                 checkpoint_epoch=None,
                 cuda=True,
                 logging=False,
                 ablation='no',
                 train_mode='hv'):
        if checkpoint_path is None:
            # Save to current directory
            self.checkpoint_path = os.getcwd()
        else:
            self.checkpoint_path = checkpoint_path
            if not os.path.isdir(self.checkpoint_path):
                try:
                    os.mkdir(self.checkpoint_path)
                except OSError:
                    pass

        if cp_name:
            self.save_epoch_fmt_task = os.path.join(self.checkpoint_path,
                                                    'task' + cp_name)
            self.save_epoch_fmt_domain = os.path.join(self.checkpoint_path,
                                                      'Domain_{}' + cp_name)
        else:
            self.save_epoch_fmt_task = os.path.join(self.checkpoint_path,
                                                    'task_checkpoint_{}ep.pt')
            self.save_epoch_fmt_domain = os.path.join(self.checkpoint_path,
                                                      'Domain_{}.pt')

        self.cuda_mode = cuda
        self.feature_extractor = models_dict['feature_extractor']
        self.task_classifier = models_dict['task_classifier']
        self.domain_discriminator_list = models_dict[
            'domain_discriminator_list']
        self.optimizer_task = optimizer_task
        self.source_loader = source_loader
        self.test_source_loader = test_source_loader
        self.target_loader = target_loader
        self.history = {
            'loss_task': [],
            'loss_total': [],
            'loss_domain': [],
            'accuracy_source': [],
            'accuracy_target': [],
            'loss_task_val_source': []
        }
        self.cur_epoch = 0
        self.total_iter = 0
        self.nadir_slack = nadir_slack
        self.alpha = alpha
        self.ablation = ablation
        self.train_mode = train_mode
        self.device = next(self.feature_extractor.parameters()).device
        self.args = args

        # Iterations per epoch (round up when the last batch is partial).
        n_batches, remainder = divmod(len(source_loader.dataset),
                                      source_loader.batch_size)
        its_per_epoch = n_batches + 1 if remainder > 0 else n_batches
        patience = patience * (1 + its_per_epoch)
        self.after_scheduler_task = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer_task,
            factor=factor,
            patience=patience,
            verbose=verbose > 0,
            threshold=lr_threshold,
            min_lr=1e-7)
        self.after_scheduler_disc_list = [
            torch.optim.lr_scheduler.ReduceLROnPlateau(
                disc.optimizer,
                factor=factor,
                patience=patience,
                verbose=verbose > 0,
                threshold=lr_threshold,
                min_lr=1e-7) for disc in self.domain_discriminator_list
        ]
        self.verbose = verbose
        self.save_cp = save_cp

        self.scheduler_task = GradualWarmupScheduler(
            self.optimizer_task,
            total_epoch=warmup_its,
            after_scheduler=self.after_scheduler_task)
        self.scheduler_disc_list = [
            GradualWarmupScheduler(self.domain_discriminator_list[i].optimizer,
                                   total_epoch=warmup_its,
                                   after_scheduler=sch_disc)
            for i, sch_disc in enumerate(self.after_scheduler_disc_list)
        ]

        if checkpoint_epoch is not None:
            self.load_checkpoint(checkpoint_epoch)

        self.logging = logging
        if self.logging:
            from torch.utils.tensorboard import SummaryWriter
            log_path = args.checkpoint_path + 'runs/'
            self.writer = SummaryWriter(log_path)

        if label_smoothing > 0.0:
            self.ce_criterion = LabelSmoothingLoss(label_smoothing,
                                                   lbl_set_size=7)
        else:
            self.ce_criterion = torch.nn.CrossEntropyLoss()

    def train(self, n_epochs=1, save_every=1):

        while self.cur_epoch < n_epochs:

            print('Epoch {}/{}'.format(self.cur_epoch + 1, n_epochs))

            self.cur_loss_task = 0
            self.cur_hypervolume = 0
            self.cur_loss_total = 0

            source_iter = tqdm(enumerate(self.source_loader),
                               total=len(self.source_loader),
                               disable=False)

            for t, batch in source_iter:
                if self.ablation == 'all':
                    cur_losses = self.train_step_ablation_all(batch)
                else:
                    cur_losses = self.train_step(batch)

                self.scheduler_task.step(epoch=self.total_iter,
                                         metrics=1. -
                                         self.history['accuracy_source'][-1]
                                         if self.cur_epoch > 0 else np.inf)
                for sched in self.scheduler_disc_list:
                    sched.step(epoch=self.total_iter,
                               metrics=1. - self.history['accuracy_source'][-1]
                               if self.cur_epoch > 0 else np.inf)

                self.cur_loss_task += cur_losses[0]
                self.cur_hypervolume += cur_losses[1]
                self.cur_loss_total += cur_losses[2]
                self.total_iter += 1

                if self.logging:
                    self.writer.add_scalar('Iteration/task_loss',
                                           cur_losses[0], self.total_iter)
                    self.writer.add_scalar('Iteration/domain_loss',
                                           cur_losses[1], self.total_iter)
                    self.writer.add_scalar('Iteration/total_loss',
                                           cur_losses[2], self.total_iter)

            self.history['loss_task'].append(self.cur_loss_task / (t + 1))
            self.history['loss_domain'].append(self.cur_hypervolume / (t + 1))
            self.history['loss_total'].append(self.cur_loss_total / (t + 1))

            acc_source, loss_task_val_source = test(
                self.test_source_loader,
                self.feature_extractor,
                self.task_classifier,
                self.domain_discriminator_list,
                self.device,
                source_target='source',
                epoch=self.cur_epoch,
                tb_writer=self.writer if self.logging else None)
            acc_target, loss_task_target = test(
                self.target_loader,
                self.feature_extractor,
                self.task_classifier,
                self.domain_discriminator_list,
                self.device,
                source_target='target',
                epoch=self.cur_epoch,
                tb_writer=self.writer if self.logging else None)

            self.history['accuracy_source'].append(acc_source)
            self.history['accuracy_target'].append(acc_target)
            self.history['loss_task_val_source'].append(loss_task_val_source)

            self.source_epoch_best_loss_task = np.argmin(
                self.history['loss_task'])
            self.source_epoch_best_loss_domain = np.argmin(
                self.history['loss_domain'])
            self.source_epoch_best_loss_total = np.argmin(
                self.history['loss_total'])
            self.source_epoch_best_loss_task_val = np.argmin(
                self.history['loss_task_val_source'])
            self.source_epoch_best = np.argmax(self.history['accuracy_source'])
            self.target_epoch_best = np.argmax(self.history['accuracy_target'])

            self.source_best_acc = np.max(self.history['accuracy_source'])
            self.target_best_loss_task = self.history['accuracy_target'][
                self.source_epoch_best_loss_task]
            self.target_best_loss_domain = self.history['accuracy_target'][
                self.source_epoch_best_loss_domain]
            self.target_best_loss_total = self.history['accuracy_target'][
                self.source_epoch_best_loss_total]
            self.target_best_source_acc = self.history['accuracy_target'][
                self.source_epoch_best]
            self.target_best_acc = np.max(self.history['accuracy_target'])
            self.target_best_acc_loss_task_val = self.history[
                'accuracy_target'][self.source_epoch_best_loss_task_val]

            self.print_results()

            if self.logging:
                self.writer.add_scalar(
                    'misc/LR-task', self.optimizer_task.param_groups[0]['lr'],
                    self.total_iter)
                for i, disc in enumerate(self.domain_discriminator_list):
                    self.writer.add_scalar(
                        'misc/LR-disc{}'.format(i),
                        disc.optimizer.param_groups[0]['lr'], self.total_iter)
                self.writer.add_scalar('Epoch/Loss-total',
                                       self.history['loss_total'][-1],
                                       self.cur_epoch)
                self.writer.add_scalar('Epoch/Loss-task',
                                       self.history['loss_task'][-1],
                                       self.cur_epoch)
                self.writer.add_scalar('Epoch/Loss-domain',
                                       self.history['loss_domain'][-1],
                                       self.cur_epoch)
                self.writer.add_scalar(
                    'Epoch/Loss-task-val',
                    self.history['loss_task_val_source'][-1], self.cur_epoch)
                self.writer.add_scalar('Epoch/Acc-Source',
                                       self.history['accuracy_source'][-1],
                                       self.cur_epoch)
                self.writer.add_scalar('Epoch/Acc-target',
                                       self.history['accuracy_target'][-1],
                                       self.cur_epoch)

            self.cur_epoch += 1

            if self.save_cp and (
                    self.cur_epoch % save_every == 0
                    or self.history['accuracy_target'][-1] >
                    np.max([-np.inf] + self.history['accuracy_target'][:-1])):
                self.checkpointing()

        if self.logging:
            self.writer.close()

        results_acc = [
            self.target_best_loss_task, self.target_best_loss_domain,
            self.target_best_loss_total, self.target_best_source_acc,
            self.source_best_acc, self.target_best_acc
        ]
        results_epochs = [
            self.source_epoch_best_loss_task,
            self.source_epoch_best_loss_domain,
            self.source_epoch_best_loss_total, self.source_epoch_best,
            self.target_epoch_best
        ]

        return np.min(
            self.history['loss_task_val_source']), results_acc, results_epochs

    def train_step(self, batch):
        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, y_domain_1, y_domain_2, y_domain_3 = batch

        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)
        y_domain = torch.cat((y_domain_1, y_domain_2, y_domain_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
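        # Detached copy: the discriminator updates below must not flow back into
        # the feature extractor.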
        features_ = features.detach()

        # DOMAIN DISCRIMINATORS
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features_).squeeze()
            curr_y_domain = torch.where(y_domain == i,
                                        torch.ones(y_domain.size(0)),
                                        torch.zeros(y_domain.size(0)))

            if self.cuda_mode:
                curr_y_domain = curr_y_domain.float().to(self.device)

            loss_domain_discriminator = F.binary_cross_entropy_with_logits(
                y_predict, curr_y_domain)

            if self.logging:
                self.writer.add_scalar('train/D{}_loss'.format(i),
                                       loss_domain_discriminator,
                                       self.total_iter)

            disc.optimizer.zero_grad()
            loss_domain_discriminator.backward()
            disc.optimizer.step()

        # UPDATE TASK CLASSIFIER AND FEATURE EXTRACTOR
        task_out = self.task_classifier.forward(features)

        loss_domain_disc_list = []
        loss_domain_disc_list_float = []
        for i, disc in enumerate(self.domain_discriminator_list):
            y_predict = disc.forward(features).squeeze()
            curr_y_domain = torch.where(y_domain == i,
                                        torch.zeros(y_domain.size(0)),
                                        torch.ones(y_domain.size(0)))

            if self.cuda_mode:
                curr_y_domain = curr_y_domain.float().to(self.device)

            loss_domain_disc_list.append(
                F.binary_cross_entropy_with_logits(y_predict, curr_y_domain))
            loss_domain_disc_list_float.append(
                loss_domain_disc_list[-1].detach().item())

        if self.train_mode == 'hv':
            self.update_nadir_point(loss_domain_disc_list_float)

        hypervolume = 0
        for loss in loss_domain_disc_list:
            if self.train_mode == 'hv':
                hypervolume -= torch.log(self.nadir - loss + 1e-6)
            elif self.train_mode == 'avg':
                hypervolume -= loss

        task_loss = self.ce_criterion(task_out, y_task)
        loss_total = self.alpha * task_loss + (
            1 - self.alpha) * hypervolume / len(loss_domain_disc_list)

        self.optimizer_task.zero_grad()
        loss_total.backward()
        self.optimizer_task.step()

        losses_return = task_loss.item(), hypervolume.item(), loss_total.item()
        return losses_return

    def train_step_ablation_all(self, batch):

        self.feature_extractor.train()
        self.task_classifier.train()
        for disc in self.domain_discriminator_list:
            disc.train()

        x_1, x_2, x_3, y_task_1, y_task_2, y_task_3, _, _, _ = batch

        x = torch.cat((x_1, x_2, x_3), dim=0)
        y_task = torch.cat((y_task_1, y_task_2, y_task_3), dim=0)

        if self.cuda_mode:
            x = x.to(self.device)
            y_task = y_task.to(self.device)

        # COMPUTING FEATURES
        features = self.feature_extractor.forward(x)
        task_out = self.task_classifier.forward(features)
        task_loss = torch.nn.CrossEntropyLoss()(task_out, y_task)

        self.optimizer_task.zero_grad()
        task_loss.backward()
        self.optimizer_task.step()

        # Return a 3-tuple to match train_step: (task loss, hypervolume, total loss);
        # with the discriminators ablated, the total loss is the task loss itself.
        losses_return = task_loss.item(), 0., task_loss.item()

        return losses_return

    def checkpointing(self):
        if self.verbose > 0:
            print(' ')
            print('Checkpointing...')

        ckpt = {
            'feature_extractor_state': self.feature_extractor.state_dict(),
            'task_classifier_state': self.task_classifier.state_dict(),
            'optimizer_task_state': self.optimizer_task.state_dict(),
            'scheduler_task_state': self.scheduler_task.state_dict(),
            'history': self.history,
            'cur_epoch': self.cur_epoch
        }
        torch.save(ckpt, self.save_epoch_fmt_task.format(self.cur_epoch))

        for i, disc in enumerate(self.domain_discriminator_list):
            ckpt = {
                'model_state': disc.state_dict(),
                'optimizer_disc_state': disc.optimizer.state_dict(),
                'scheduler_disc_state':
                self.scheduler_disc_list[i].state_dict()
            }
            torch.save(ckpt, self.save_epoch_fmt_domain.format(i + 1))

    def load_checkpoint(self, epoch):
        ckpt = self.save_epoch_fmt_task.format(epoch)

        if os.path.isfile(ckpt):

            ckpt = torch.load(ckpt)
            # Load model states (the domain discriminators are restored below)
            self.feature_extractor.load_state_dict(
                ckpt['feature_extractor_state'])
            self.task_classifier.load_state_dict(ckpt['task_classifier_state'])
            # Load optimizer state
            self.optimizer_task.load_state_dict(ckpt['optimizer_task_state'])
            # Load scheduler state
            self.scheduler_task.load_state_dict(ckpt['scheduler_task_state'])
            # Load history
            self.history = ckpt['history']
            self.cur_epoch = ckpt['cur_epoch']

            for i, disc in enumerate(self.domain_discriminator_list):
                ckpt = torch.load(self.save_epoch_fmt_domain.format(i + 1))
                disc.load_state_dict(ckpt['model_state'])
                disc.optimizer.load_state_dict(ckpt['optimizer_disc_state'])
                self.scheduler_disc_list[i].load_state_dict(
                    ckpt['scheduler_disc_state'])

        else:
            print('No checkpoint found at: {}'.format(ckpt))

    def print_grad_norms(self, model):
        norm = 0.0
        for params in list(
                filter(lambda p: p.grad is not None, model.parameters())):
            norm += params.grad.norm(2).item()
        print('Sum of grads norms: {}'.format(norm))

    def update_nadir_point(self, losses_list):
        self.nadir = float(np.max(losses_list) * self.nadir_slack + 1e-8)

    def print_results(self):
        print('Current task loss: {}.'.format(self.history['loss_task'][-1]))
        print('Current hypervolume: {}.'.format(
            self.history['loss_domain'][-1]))
        print('Current total loss: {}.'.format(self.history['loss_total'][-1]))

        print('VALIDATION ON SOURCE DOMAINS - {}, {}, {}'.format(
            self.args.source1, self.args.source2, self.args.source3))
        print('Current, best, and epoch: {:0.4f}, {:0.4f}, {}'.format(
            self.history['accuracy_source'][-1], self.source_best_acc,
            self.source_epoch_best))

        print('VALIDATION ON TARGET DOMAIN - {}'.format(self.args.target))
        print('Current, best, and epoch: {:0.4f}, {:0.4f}, {}'.format(
            self.history['accuracy_target'][-1], self.target_best_acc,
            self.target_epoch_best))

        print('VALIDATION ON TARGET DOMAIN - BEST TOTAL LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(
            self.target_best_loss_total, self.source_epoch_best_loss_total))

        print('VALIDATION ON TARGET DOMAIN - BEST SOURCE VAL ACC')
        print('Best and epoch: {:0.4f}, {}'.format(self.target_best_source_acc,
                                                   self.source_epoch_best))

        print('VALIDATION ON TARGET DOMAIN - BEST TASK LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(
            self.target_best_loss_task, self.source_epoch_best_loss_task))

        print('VALIDATION ON TARGET DOMAIN - BEST DOMAIN DISC LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(
            self.target_best_loss_domain, self.source_epoch_best_loss_domain))

        print('VALIDATION ON TARGET DOMAIN - BEST VAL TASK LOSS')
        print('Best and epoch: {:0.4f}, {}'.format(
            self.target_best_acc_loss_task_val,
            self.source_epoch_best_loss_task_val))
Example #5
import torch

from utils import GradualWarmupScheduler

if __name__ == '__main__':
    v = torch.zeros(10)
    optim = torch.optim.SGD([v], lr=0.01)
    scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=10)

    for epoch in range(1, 20):
        scheduler.step(epoch)

        print(epoch, optim.param_groups[0]['lr'])
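
The same demo can chain the warmup into an after_scheduler, mirroring the pattern used in Examples #2 and #6; this is a minimal sketch, with an illustrative warmup length and cosine horizon.

import torch

from utils import GradualWarmupScheduler

if __name__ == '__main__':
    v = torch.zeros(10)
    optim = torch.optim.SGD([v], lr=0.01)
    # After 5 warmup epochs, the wrapped CosineAnnealingLR takes over.
    cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=10)
    scheduler = GradualWarmupScheduler(optim, multiplier=8, total_epoch=5,
                                       after_scheduler=cosine)

    for epoch in range(1, 20):
        scheduler.step(epoch)
        print(epoch, optim.param_groups[0]['lr'])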
Example #6
import logging
import multiprocessing
import os

import pandas as pd
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torch.utils.tensorboard import SummaryWriter

# Also assumed from the surrounding project (definitions not shown here): config,
# seed_torch, setup_logger, get_transforms, PANDADataset, PANDADatasetTiles,
# RegnetNetVLAD, OptimizedRounder, train_fn, eval_fn, save_model,
# GradualWarmupScheduler, and NVIDIA apex's amp (when config.apex is set).

def run():
    seed_torch(seed=config.seed)
    os.makedirs(config.MODEL_PATH, exist_ok=True)
    setup_logger(config.MODEL_PATH + 'log.txt')
    writer = SummaryWriter(config.MODEL_PATH)

    folds = pd.read_csv(config.fold_csv)
    # Default to None so the dataset constructors below receive an explicit value
    # when no tile-stats file is configured.
    attention_df = None
    if config.tile_stats_csv:
        attention_df = pd.read_csv(config.tile_stats_csv)

    if config.DEBUG:
        folds = folds.sample(
            n=50, random_state=config.seed).reset_index(drop=True).copy()

    logging.info(f"fold: {config.fold}")
    fold = config.fold
    # train/val split
    trn_idx = folds[folds[f'fold_{fold}'] == 0].index
    val_idx = folds[folds[f'fold_{fold}'] == 1].index

    df_train = folds.loc[trn_idx]
    df_val = folds.loc[val_idx]
    # ------single image------
    if config.strategy == 'stitched':
        train_dataset = PANDADataset(image_folder=config.DATA_PATH,
                                     df=df_train,
                                     image_size=config.IMG_SIZE,
                                     num_tiles=config.num_tiles,
                                     rand=False,
                                     transform=get_transforms(phase='train'),
                                     attention_df=attention_df)
        valid_dataset = PANDADataset(image_folder=config.DATA_PATH,
                                     df=df_val,
                                     image_size=config.IMG_SIZE,
                                     num_tiles=config.num_tiles,
                                     rand=False,
                                     transform=get_transforms(phase='valid'),
                                     attention_df=attention_df)

    #------image tiles------
    else:
        train_dataset = PANDADatasetTiles(
            image_folder=config.DATA_PATH,
            df=df_train,
            image_size=config.IMG_SIZE,
            num_tiles=config.num_tiles,
            transform=get_transforms(phase='train'),
            attention_df=attention_df)
        valid_dataset = PANDADatasetTiles(
            image_folder=config.DATA_PATH,
            df=df_val,
            image_size=config.IMG_SIZE,
            num_tiles=config.num_tiles,
            transform=get_transforms(phase='valid'),
            attention_df=attention_df)

    train_loader = DataLoader(train_dataset,
                              batch_size=config.batch_size,
                              sampler=RandomSampler(train_dataset),
                              num_workers=multiprocessing.cpu_count(),
                              pin_memory=True)
    val_loader = DataLoader(valid_dataset,
                            batch_size=config.batch_size,
                            sampler=SequentialSampler(valid_dataset),
                            num_workers=multiprocessing.cpu_count(),
                            pin_memory=True)

    device = torch.device("cuda")
    # Alternative models from other experiments (EnetNetVLAD, EnetV1, Regnet, and the
    # attention-head EfficientModel used to generate tile weights) are omitted here.
    model = RegnetNetVLAD(num_clusters=config.num_cluster,
                          num_tiles=config.num_tiles,
                          num_classes=config.num_class,
                          ckpt=config.pretrain_model)
    model = model.to(device)
    if config.multi_gpu:
        model = torch.nn.DataParallel(model)
    if config.ckpt_path:
        model.load_state_dict(torch.load(config.ckpt_path))
    warmup_factor = 10
    warmup_epo = 1
    optimizer = Adam(model.parameters(), lr=config.lr / warmup_factor)
    scheduler_cosine = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, config.num_epoch - warmup_epo)
    scheduler = GradualWarmupScheduler(optimizer,
                                       multiplier=warmup_factor,
                                       total_epoch=warmup_epo,
                                       after_scheduler=scheduler_cosine)

    if config.apex:
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level="O1",
                                          verbosity=0)

    best_score = 0.
    best_loss = float('inf')  # so the first validation loss is always checkpointed
    if config.model_type == 'reg':
        optimized_rounder = OptimizedRounder()
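    # Dummy step so the warmup scheduler's first step() does not warn about
    # being called before the optimizer.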
    optimizer.zero_grad()
    optimizer.step()
    for epoch in range(1, config.num_epoch + 1):
        if scheduler:
            scheduler.step(epoch - 1)
        if config.model_type != 'reg':
            train_fn(train_loader, model, optimizer, device, epoch, writer,
                     df_train)
            metric = eval_fn(val_loader, model, device, epoch, writer, df_val)
        else:
            coefficients = train_fn(train_loader, model, optimizer, device,
                                    epoch, writer, df_train, optimized_rounder)
            metric = eval_fn(val_loader, model, device, epoch, writer, df_val,
                             coefficients)
        score = metric['score']
        val_loss = metric['loss']
        if score > best_score:
            best_score = score
            logging.info(f"Epoch {epoch} - found best score {best_score}")
            save_model(model, config.MODEL_PATH + f"best_kappa_f{fold}.pth")
        if val_loss < best_loss:
            best_loss = val_loss
            logging.info(f"Epoch {epoch} - found best loss {best_loss}")
            save_model(model, config.MODEL_PATH + f"best_loss_f{fold}.pth")