Beispiel #1
0
def train_data_iterator(entities, triples):
    """Build a bidirectional one-shot iterator over two training graphs.

    Parameters
    ----------
    entities : tuple
        ``(entities_1, entities_2)`` entity collections for the two graphs.
    triples : tuple
        ``(triples_1, triples_2)`` training triples for the two graphs.

    Returns
    -------
    BidirectionalOneShotIterator
        Alternates over head-batch and tail-batch loaders of both graphs.
    """
    entities_1, entities_2 = entities
    triples_1, triples_2 = triples

    def _loader(graph_triples, graph_entities, mode):
        # All four loaders share identical batching settings; only the
        # (graph, corruption-mode) pair differs.
        return DataLoader(TrainDataset(graph_triples, graph_entities,
                                       config.neg_size, mode),
                          batch_size=config.batch_size,
                          shuffle=True,
                          num_workers=max(0, config.cpu_num // 3),
                          collate_fn=TrainDataset.collate_fn)

    return BidirectionalOneShotIterator(
        _loader(triples_1, entities_1, "head-batch"),
        _loader(triples_1, entities_1, "tail-batch"),
        _loader(triples_2, entities_2, "head-batch"),
        _loader(triples_2, entities_2, "tail-batch"))
 def prepareData(self):
     """Create the train/valid datasets and the validation DataLoader.

     The training DataLoader is deliberately left as None here; it is
     expected to be constructed later (e.g. per epoch or with a sampler).
     The validation loader serves the whole split as a single fixed-order
     batch (``batch_size=self.valid.n_triples``).
     """
     # Fixed typo in the progress message ("Perpare" -> "Prepare").
     print("Prepare dataloader")
     self.train = TrainDataset(self.dataset)
     self.trainloader = None
     self.valid = EvalDataset(self.dataset)
     self.validloader = DataLoader(self.valid,
                                   batch_size=self.valid.n_triples,
                                   shuffle=False)
Beispiel #3
0
    def train_NoiGAN(trainer):
        """Run one NoiGAN policy-gradient phase.

        Rebuilds the training iterator from confidence-filtered positive
        triples, then performs REINFORCE-style updates on the generator
        using rewards produced by the discriminator.

        ``trainer`` bundles the embedding model, generator, discriminator,
        their optimizers, the training triples and the CLI args.
        """
        # Keep the KGE embedding model frozen while the GAN parts train.
        trainer.embedding_model.eval()

        st = time.time()
        trainer.positive_triples = trainer.find_positive_triples()
        et = time.time()
        print("take %d s to find positive triples" % (et - st))

        # Rebuild head/tail-corruption datasets, then overwrite their triples
        # with the filtered positives (overrides what the constructor stored).
        trainer.train_dataset_head = TrainDataset(
            trainer.train_triples, trainer.args.nentity,
            trainer.args.nrelation, trainer.args.negative_sample_size,
            "head-batch")
        trainer.train_dataset_head.triples = trainer.positive_triples
        trainer.train_dataset_tail = TrainDataset(
            trainer.train_triples, trainer.args.nentity,
            trainer.args.nrelation, trainer.args.negative_sample_size,
            "tail-batch")
        trainer.train_dataset_tail.triples = trainer.positive_triples
        trainer.train_dataloader_head = DataLoader(
            trainer.train_dataset_head,
            batch_size=128,
            shuffle=True,
            num_workers=5,
            collate_fn=TrainDataset.collate_fn)
        trainer.train_dataloader_tail = DataLoader(
            trainer.train_dataset_tail,
            batch_size=128,
            shuffle=True,
            num_workers=5,
            collate_fn=TrainDataset.collate_fn)
        trainer.train_iterator = BidirectionalOneShotIterator(
            trainer.train_dataloader_head, trainer.train_dataloader_tail)
        # NOTE(review): each "epoch" here consumes a single batch from the
        # iterator, so this is really 1500 update steps.
        epochs = 1500
        epoch_reward, epoch_loss, avg_reward = 0, 0, 0
        for i in range(epochs):
            trainer.generator.train()
            positive_sample, negative_sample, subsampling_weight, mode = next(
                trainer.train_iterator)
            if trainer.args.cuda:
                positive_sample = positive_sample.cuda()  # [batch_size, 3]
                negative_sample = negative_sample.cuda(
                )  # [batch_size, negative_sample_size]
            # Generator proposes negatives; discriminator scores/rewards them.
            pos, neg, scores, sample_idx, row_idx = trainer.generate(
                positive_sample, negative_sample, mode)
            loss, rewards = trainer.discriminate(pos, neg, mode)
            epoch_reward += torch.sum(rewards)
            epoch_loss += loss
            # Baseline subtraction; avg_reward is never updated in this
            # function, so it stays 0 here.
            rewards = rewards - avg_reward

            # REINFORCE update: reward-weighted log-probabilities of the
            # sampled negatives drive the generator gradient.
            trainer.generator.zero_grad()
            log_probs = F.log_softmax(scores, dim=1)
            reinforce_loss = torch.sum(
                Variable(rewards) * log_probs[row_idx.cuda(), sample_idx.data])
            reinforce_loss.backward()
            trainer.gen_optimizer.step()
            trainer.generator.eval()
Beispiel #4
0
def main(args):
    """End-to-end training/validation/prediction driver.

    Builds dataloaders from the pickled word2id vocabulary, trains the
    model selected by ``args.model_name``, keeps the checkpoint with the
    best validation F1, and finally reports prediction metrics with the
    best model.
    """
    #### build dataloader and logger
    word2id = pickle.load(open(args.word2id, "rb"))
    if args.word_num is None:
        args.word_num = len(word2id)
    set_logger(args)
    logging.info(args)
    train_dataset = TrainDataset(args.train_file, args.negative_sample_size, word2id=word2id)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                                  num_workers=max(1, args.cpu_num // 2), collate_fn=TrainDataset.collate_fn)
    # Negative-only loader; currently unused (its do_train call is disabled).
    train_dataset_neg = TrainDataset(args.train_file, 1, word2id=word2id, if_pos=False)
    train_dataloader_neg = DataLoader(train_dataset_neg, batch_size=args.batch_size, shuffle=True,
                                  num_workers=max(1, args.cpu_num // 2), collate_fn=TrainDataset.collate_fn)

    valid_dataset = TestDataset(args.valid_file)
    valid_dataloader = DataLoader(valid_dataset, batch_size=1, shuffle=True,
                                  num_workers=max(1, args.cpu_num // 2), collate_fn=TestDataset.collate_fn)

    #### build model
    model = model_name2model[args.model_name](args)
    if args.pretrain is not None and os.path.exists(args.pretrain):
        logging.info("using %s as pretrain word embedding" % args.pretrain)
        pretrained = torch.load(args.pretrain)
        model.embedding.weight.data.copy_(pretrained)

    num_steps = len(train_dataloader) * args.num_epochs
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate, weight_decay=0.01)
    # max(1, ...) guards against step_size == 0 (rejected by StepLR) when the
    # dataset / epoch count is tiny.
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, max(1, num_steps // 2), gamma=0.1)
    warmup_scheduler = None #warmup.UntunedLinearWarmup(optimizer)
    if args.cuda:
        model = model.cuda()

    #### begin training
    logging.info("begin training:")
    best_model, best_f1 = None, 0
    for epoch in range(args.num_epochs):
        model.do_train(model, optimizer, train_dataloader, args, lr_scheduler, warmup_scheduler)
        if epoch % model.config.valid_epochs == 0:
            metric = model.do_valid(model, valid_dataloader, args)
            log_metrics("epoch %d" % epoch, metric)
            if metric["f1"] > best_f1:
                best_model = model.state_dict()
                best_f1 = metric["f1"]
                torch.save(best_model, os.path.join(args.save_path, "best_%s.pt" % args.name))

    # Guard: if validation never improved on the initial best_f1 (or never
    # ran), best_model is still None and load_state_dict(None) would crash;
    # fall back to the final weights in that case.
    if best_model is not None:
        model.load_state_dict(best_model)
    log = model.do_prediction(model, valid_dataloader, args, word2id)
    log_metrics("final", log)
Beispiel #5
0
def train_data_iterator(train_triples, ent_num):
    """Return a bidirectional iterator alternating head/tail corruption."""
    loaders = []
    for mode in ("head-batch", "tail-batch"):
        dataset = TrainDataset(train_triples, ent_num,
                               config.neg_size, mode)
        loaders.append(DataLoader(dataset,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=max(0, config.cpu_num // 3),
                                  collate_fn=TrainDataset.collate_fn))
    return BidirectionalOneShotIterator(loaders[0], loaders[1])
Beispiel #6
0
def train(args):
    """Fit a VAE on the training set with the supplied hyper-parameters."""
    loader = DataLoader(TrainDataset(),
                        batch_size=args.batch_size,
                        shuffle=True)
    model = VAE().to(DEVICE)
    model.fit(loader, n_epochs=args.num_epochs, lr=args.lr)
Beispiel #7
0
 def load_dataset(self,
                  mode='train',
                  random_scale=True,
                  rotate=True,
                  fliplr=True,
                  fliptb=True):
     """Return a DataLoader for the requested split.

     ``mode`` selects 'train' (augmented, shuffled), 'valid' (shuffled)
     or 'test' (fixed order); the augmentation flags apply to the
     training split only. An unknown mode yields None.
     """
     if mode == 'train':
         dataset = TrainDataset(os.path.join(self.data_dir,
                                             self.train_dataset),
                                crop_size=self.crop_size,
                                scale_factor=self.scale_factor,
                                random_scale=random_scale,
                                rotate=rotate,
                                fliplr=fliplr,
                                fliptb=fliptb)
         batch_size, shuffle = self.batch_size, True
     elif mode == 'valid':
         dataset = DevDataset(os.path.join(self.data_dir,
                                           self.valid_dataset))
         batch_size, shuffle = self.valid_batch_size, True
     elif mode == 'test':
         dataset = TestDataset(os.path.join(self.data_dir,
                                            self.test_dataset))
         batch_size, shuffle = self.test_batch_size, False
     else:
         return None
     return DataLoader(dataset=dataset,
                       num_workers=self.num_threads,
                       batch_size=batch_size,
                       shuffle=shuffle)
Beispiel #8
0
 def init_dataset(self):
     """Build head/tail-corruption DataLoaders and the training iterator."""
     loaders = []
     for mode in ('head-batch', 'tail-batch'):
         dataset = TrainDataset(self.train_triples, self.entity_count,
                                self.attr_count, self.value_count, 512,
                                mode)
         loaders.append(DataLoader(dataset,
                                   batch_size=1024,
                                   shuffle=False,
                                   num_workers=4,
                                   collate_fn=TrainDataset.collate_fn))
     self.train_iterator = BidirectionalOneShotIterator(loaders[0],
                                                        loaders[1])
Beispiel #9
0
def get_dataloader(batch_size=256):
    """Return (train, valid, test) DataLoaders sharing one batch size."""
    splits = (TrainDataset(), ValidDataset(), TestDataset())
    return tuple(
        DataLoader(split, batch_size=batch_size, shuffle=False)
        for split in splits)
Beispiel #10
0
def train_data_iterator(train_triples, ent_num):
    """Build the bidirectional head/tail training iterator."""
    def _make(mode):
        # One loader per corruption mode, identical batching settings.
        return DataLoader(
            TrainDataset(train_triples, ent_num, config.neg_size, mode),
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=TrainDataset.collate_fn)

    return BidirectionalOneShotIterator(_make("head-batch"),
                                        _make("tail-batch"))
Beispiel #11
0
def run(params,dirs,seed=None,restore_file=None):
    """Train a DeepAR model, then evaluate the best checkpoint on the test set.

    params: hyper-parameter container (batch_size, lr, num_epochs, ...).
    dirs: path/config container; mutated here (``dirs.device`` is set).
    seed: optional int to fix RNG state for reproducible experiments.
    restore_file: optional checkpoint name forwarded to train_and_evaluate.
    """
    # set random seed to do reproducible experiments
    if seed is not None:
        utils.seed(seed)
    utils.set_logger(os.path.join(dirs.model_dir, 'train.log'))
    logger = logging.getLogger('DeepAR.Train')
    # check whether cuda is available and place the model accordingly
    use_cuda=torch.cuda.is_available()
    if use_cuda:
        dirs.device = torch.device('cuda:0')
        logger.info('Using Cuda...')
        model = net.Net(params,dirs.device).cuda(dirs.device)
    else:
        dirs.device = torch.device('cpu')
        logger.info('Not using cuda...')
        model = net.Net(params,dirs.device)

    logger = logging.getLogger('DeepAR.Data')
    logger.info('Loading the datasets...')
    train_set = TrainDataset(dirs.data_dir, dirs.dataset)
    vali_set = ValiDataset(dirs.data_dir, dirs.dataset)
    test_set = TestDataset(dirs.data_dir, dirs.dataset)
    # Validation/test loaders sample randomly; the train loader keeps
    # dataset order (no shuffle/sampler given).
    train_loader = DataLoader(train_set,batch_size=params.batch_size,pin_memory=False,
                              num_workers=4)
    vali_loader = DataLoader(vali_set,batch_size=params.batch_size,pin_memory=False,
                             sampler=RandomSampler(vali_set),num_workers=4)
    test_loader = DataLoader(test_set,batch_size=params.batch_size,pin_memory=False,
                             sampler=RandomSampler(test_set),num_workers=4)
    logger.info('Data loading complete.')
    logger.info('###############################################\n')

    logger = logging.getLogger('DeepAR.Train')
    logger.info(f'Model: \n{str(model)}')
    logger.info('###############################################\n')

    optimizer = optim.Adam(model.parameters(), lr=params.lr)
    # fetch loss function
    loss_fn = net.loss_fn
    # Train the model
    logger.info('Starting training for {} epoch(s)'.format(params.num_epochs))
    # NOTE(review): `scheduler` is never defined in this function (nor visibly
    # in this file), so this call raises NameError at runtime. Define an LR
    # scheduler (or pass None) before calling train_and_evaluate.
    train_and_evaluate(model,train_loader,vali_loader,optimizer,loss_fn,scheduler,params,dirs,restore_file)
    logger.handlers.clear()
    logging.shutdown()

    # Evaluate the best saved checkpoint on the test split, if one exists.
    load_dir = os.path.join(dirs.model_save_dir, 'best.pth.tar')
    if not os.path.exists(load_dir):
        return
    utils.load_checkpoint(load_dir, model)
    out=evaluate(model, loss_fn, test_loader, params, dirs, istest=True)
    test_json_path=os.path.join(dirs.model_dir,'test_results.json')
    utils.save_dict_to_json(out, test_json_path)
Beispiel #12
0
def construct_dataloader(args, train_triples, nentity, nrelation):
    """Assemble the bidirectional (head/tail corruption) training iterator."""
    workers = max(1, args.cpu_num // 2)
    loaders = [
        DataLoader(TrainDataset(train_triples, nentity, nrelation,
                                args.negative_sample_size, mode),
                   batch_size=args.batch_size,
                   shuffle=True,
                   num_workers=workers,
                   collate_fn=TrainDataset.collate_fn)
        for mode in ('head-batch', 'tail-batch')
    ]
    return BidirectionalOneShotIterator(loaders[0], loaders[1])
Beispiel #13
0
def main(args):
    """Train / evaluate a KGE model on a dictionary-indexed triple dataset.

    Reads entity/relation dictionaries and train/valid/test triples from
    ``args.data_path``, optionally restores a checkpoint, runs the training
    loop with warm-up learning-rate decay, and evaluates according to the
    ``do_valid`` / ``do_test`` / ``evaluate_train`` flags.
    """
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be choosed.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be choosed.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # All true triples (used to filter corrupted negatives during evaluation)
    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        type_dim=args.type_dim,
        gamma=args.gamma,
        gamma_type=args.gamma_type,
        gamma_pair=args.gamma_pair,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, args.pair_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, args.pair_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()),
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('type_dim = %d' % args.type_dim)
    logging.info('gamma_type = %f' % args.gamma_type)
    logging.info('alpha_1 = %f' % args.alpha_1)
    logging.info('gamma_pair = %f' % args.gamma_pair)
    logging.info('alpha_2 = %f' % args.alpha_2)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    logging.info('pair_sample_size = %d' % args.pair_sample_size)
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        # %f (not %d): the learning rate is fractional; %d would log it as 0.
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):
            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)

            training_logs.append(log)

            # After warm-up is exhausted: decay the LR by 10x and triple the
            # next warm-up horizon (rebuilding the optimizer resets its state).
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        # Fixed copy-paste label: these are training-set metrics, not test.
        log_metrics('Train', step, metrics)
Beispiel #14
0
def main(args):
    """Train / evaluate a KGE model on the ogbl-biokg link-prediction task.

    Loads the OGB edge split, builds per-entity-type contiguous id ranges,
    optionally restores a checkpoint, trains with warm-up learning-rate
    decay while logging to TensorBoard, and evaluates according to the
    ``do_*`` / ``evaluate_train`` flags.
    """
    if (not args.do_train) and (not args.do_valid) and (not args.do_test) and (
            not args.evaluate_train):
        raise ValueError('one of train/val/test mode must be choosed.')

    if args.init_checkpoint:
        override_config(args)

    args.save_path = 'log/%s/%s/%s-%s/%s' % (
        args.dataset, args.model, args.hidden_dim, args.gamma,
        time.time()) if args.save_path is None else args.save_path
    writer = SummaryWriter(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    dataset = LinkPropPredDataset(name='ogbl-biokg')
    split_edge = dataset.get_edge_split()
    train_triples, valid_triples, test_triples = split_edge[
        "train"], split_edge["valid"], split_edge["test"]
    nrelation = int(max(train_triples['relation'])) + 1
    # Map each entity type to its contiguous (start, end) global-id range.
    entity_dict = dict()
    cur_idx = 0
    for key in dataset[0]['num_nodes_dict']:
        entity_dict[key] = (cur_idx,
                            cur_idx + dataset[0]['num_nodes_dict'][key])
        cur_idx += dataset[0]['num_nodes_dict'][key]
    nentity = sum(dataset[0]['num_nodes_dict'].values())

    evaluator = Evaluator(name=args.dataset)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Dataset: %s' % args.dataset)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    logging.info('#train: %d' % len(train_triples['head']))
    logging.info('#valid: %d' % len(valid_triples['head']))
    logging.info('#test: %d' % len(test_triples['head']))

    # Frequency counts start at 4 (count smoothing for subsampling weights);
    # true-head/true-tail maps record known positives per (rel, entity) key.
    train_count, train_true_head, train_true_tail = defaultdict(
        lambda: 4), defaultdict(list), defaultdict(list)
    for i in tqdm(range(len(train_triples['head']))):
        head, relation, tail = train_triples['head'][i], train_triples[
            'relation'][i], train_triples['tail'][i]
        head_type, tail_type = train_triples['head_type'][i], train_triples[
            'tail_type'][i]
        train_count[(head, relation, head_type)] += 1
        train_count[(tail, -relation - 1, tail_type)] += 1
        train_true_head[(relation, tail)].append(head)
        train_true_tail[(head, relation)].append(tail)

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding,
        evaluator=evaluator)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.init_checkpoint:
        # Load the checkpoint early: the dataloaders below must use the
        # entity_dict the checkpointed model was trained with.
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        entity_dict = checkpoint['entity_dict']

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch', train_count,
                         train_true_head, train_true_tail, entity_dict),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch', train_count,
                         train_true_head, train_true_tail, entity_dict),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore remaining training state (after the optimizer exists).
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        # %f (not %d): the learning rate is fractional; %d would log it as 0.
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)
            training_logs.append(log)

            # After warm-up: decay LR by 10x and triple the next horizon.
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    kge_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0 and step > 0:  # ~ 41 seconds/saving
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps,
                    'entity_dict': entity_dict
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Train', step, metrics, writer)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0 and step > 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples, args,
                                              entity_dict)
                log_metrics('Valid', step, metrics, writer)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples, args,
                                      entity_dict)
        log_metrics('Valid', step, metrics, writer)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples, args,
                                      entity_dict)
        log_metrics('Test', step, metrics, writer)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        # Score only a random subset of the (large) training split.
        small_train_triples = {}
        indices = np.random.choice(len(train_triples['head']),
                                   args.ntriples_eval_train,
                                   replace=False)
        for i in train_triples:
            if 'type' in i:
                small_train_triples[i] = [train_triples[i][x] for x in indices]
            else:
                small_train_triples[i] = train_triples[i][indices]
        metrics = kge_model.test_step(kge_model,
                                      small_train_triples,
                                      args,
                                      entity_dict,
                                      random_sampling=True)
        log_metrics('Train', step, metrics, writer)
Beispiel #15
0
def main(args):
    """End-to-end KGE pipeline for the Amazon/pLogicNet setup.

    Loads entity/relation id maps and triple files, builds a KGEModel,
    optionally trains it, evaluates on valid/test, and — when ``args.record``
    is set — scores hidden triplets and per-user recommendation candidates
    and writes annotation files into the workspace directory.

    NOTE(review): several inputs below are hard-coded to one cluster layout
    (``/common/users/yz956/...``) and to the Cell_Phones_and_Accessories
    dataset — confirm before reusing this script elsewhere.
    """
    # if (not args.do_train) and (not args.do_valid) and (not args.do_test):
    #     raise ValueError('one of train/val/test mode must be choosed.')
    
    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be choosed.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)
    
    # Write logs to checkpoint and console
    set_logger(args)
    
    # with open(os.path.join(args.data_path, 'entities.dict')) as fin:
    #     entity2id = dict()
    #     id2entity = dict()
    #     for line in fin:
    #         eid, entity = line.strip().split('\t')
    #         entity2id[entity] = int(eid)
    #         id2entity[int(eid)] = entity

    # with open(os.path.join(args.data_path, 'relations.dict')) as fin:
    #     relation2id = dict()
    #     id2relation = dict()
    #     for line in fin:
    #         rid, relation = line.strip().split('\t')
    #         relation2id[relation] = int(rid)
    #         id2relation[int(rid)] = relation
    
    # # Read regions for Countries S* datasets
    # if args.countries:
    #     regions = list()
    #     with open(os.path.join(args.data_path, 'regions.list')) as fin:
    #         for line in fin:
    #             region = line.strip()
    #             regions.append(entity2id[region])
    #     args.regions = regions

    '''amazon dataset'''
    # Amazon-format files are "<name>\t<id>"; malformed lines are skipped.
    with open(os.path.join(args.data_path, 'entity2id.txt')) as fin:
        entity2id = dict()
        id2entity = dict()
        for line in fin:
            if len(line.strip().split('\t')) < 2:
                continue
            entity, eid = line.strip().split('\t')
            entity2id[entity] = int(eid)
            id2entity[int(eid)] = entity

    with open(os.path.join(args.data_path, 'relation2id.txt')) as fin:
        relation2id = dict()
        id2relation = dict()
        for line in fin:
            if len(line.strip().split('\t')) < 2:
                continue
            relation, rid = line.strip().split('\t')
            relation2id[relation] = int(rid)
            id2relation[int(rid)] = relation

    nentity = len(entity2id)
    nrelation = len(relation2id)
    
    args.nentity = nentity
    args.nrelation = nrelation
    
    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    
    # --------------------------------------------------
    # Comments by Meng:
    # During training, pLogicNet will augment the training triplets,
    # so here we load both the augmented triplets (train.txt) for training and
    # the original triplets (train_kge.txt) for evaluation.
    # Also, the hidden triplets (hidden.txt) are also loaded for annotation.
    # --------------------------------------------------
    # train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    # logging.info('#train: %d' % len(train_triples))
    # train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    # logging.info('#train original: %d' % len(train_original_triples))
    # valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'), entity2id, relation2id)
    # logging.info('#valid: %d' % len(valid_triples))
    # test_triples = read_triple(os.path.join(args.data_path, 'test.txt'), entity2id, relation2id)
    # logging.info('#test: %d' % len(test_triples))
    # hidden_triples = read_triple(os.path.join(args.workspace_path, 'hidden.txt'), entity2id, relation2id)
    # logging.info('#hidden: %d' % len(hidden_triples))

    train_triples = read_triple(os.path.join(args.workspace_path, 'train_kge.txt'), entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    train_original_triples = read_triple(os.path.join(args.data_path, 'train.txt'), entity2id, relation2id)
    logging.info('#train original: %d' % len(train_original_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'kg_val_triples_Cell_Phones_and_Accessories.txt'), entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'kg_test_triples_Cell_Phones_and_Accessories.txt'), entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))
    # Drop column 0 of the candidate matrix; presumably it is the ground-truth
    # / query item — TODO confirm against the file format.
    test_candidates = np.load(os.path.join(args.data_path, 'rec_test_candidate100.npz'))['candidates'][:, 1:]
    # test_candidates = np.load('/common/users/yz956/kg/code/OpenDialKG/cand.npy')
    # hidden_triples = read_triple(os.path.join(args.workspace_path, 'hidden.txt'), entity2id, relation2id)
    # NOTE(review): hidden triplets come from an absolute, user-specific path.
    hidden_triples = read_triple("/common/users/yz956/kg/code/KBRD/data/cpa/cpa/hidden_50.txt", entity2id, relation2id)
    logging.info('#hidden: %d' % len(hidden_triples))
    
    #All true triples
    all_true_triples = train_original_triples + valid_triples + test_triples
    
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding
    )
    
    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' % (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()
    
    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'head-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation, args.negative_sample_size, 'tail-batch'), 
            batch_size=args.batch_size,
            shuffle=True, 
            num_workers=max(1, args.cpu_num//2),
            collate_fn=TrainDataset.collate_fn
        )
        
        # Alternates head-batch / tail-batch negative sampling.
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)
        
        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad, kge_model.parameters()), 
            lr=current_learning_rate
        )
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0
    
    step = init_step
    
    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    # NOTE(review): %d truncates fractional learning rates to 0 in the log;
    # the stored value is unaffected.
    logging.info('learning_rate = %d' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('negative_adversarial_sampling = %d' % args.negative_adversarial_sampling)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    logging.info('negative_adversarial_sampling = %s' % str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' % args.adversarial_temperature)

    if args.record:
        # Dump the full option set into the workspace for reproducibility.
        local_path = args.workspace_path
        ensure_dir(local_path)

        opt = vars(args)
        with open(local_path + '/opt.txt', 'w') as fo:
            for key, val in opt.items():
                fo.write('{} {}\n'.format(key, val))
    
    # Set valid dataloader as it would be evaluated during training
    
    if args.do_train:
        training_logs = []
        
        #Training Loop
        for step in range(init_step, args.max_steps):
            
            log = kge_model.train_step(kge_model, optimizer, train_iterator, args)
            
            training_logs.append(log)
            
            # LR decay schedule: divide by 10 at each warm-up boundary and
            # push the next boundary 3x further out.
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' % (current_learning_rate, step))
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()), 
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3
            
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step, 
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)
                
            if step % args.log_steps == 0:
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum([log[metric] for log in training_logs])/len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []
                
            if args.do_valid and (step + 1) % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
                log_metrics('Valid', step, metrics)
        
        save_variable_list = {
            'step': step, 
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)
        
    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics, preds = kge_model.test_step(kge_model, valid_triples, all_true_triples, args)
        log_metrics('Valid', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on validation set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge_valid.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge_valid.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')
    
    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        # metrics, preds = kge_model.test_step(kge_model, test_triples, all_true_triples, args)
        # NOTE(review): this test_step call takes an extra test_candidates
        # argument, unlike the valid-set call above — confirm the overload.
        metrics, preds = kge_model.test_step(kge_model, test_triples, test_candidates, all_true_triples, args)
        log_metrics('Test', step, metrics)
        
        # --------------------------------------------------
        # Comments by Meng:
        # Save the prediction results of KGE on test set.
        # --------------------------------------------------

        if args.record:
            # Save the final results
            with open(local_path + '/result_kge.txt', 'w') as fo:
                for metric in metrics:
                    fo.write('{} : {}\n'.format(metric, metrics[metric]))

            # Save the predictions on test data
            with open(local_path + '/pred_kge.txt', 'w') as fo:
                for h, r, t, f, rk, l in preds:
                    fo.write('{}\t{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], f, rk))
                    for e, val in l:
                        fo.write('{}:{:.4f} '.format(id2entity[e], val))
                    fo.write('\n')

    # --------------------------------------------------
    # Comments by Meng:
    # Save the annotations on hidden triplets.
    # --------------------------------------------------

    if args.record:
        # Annotate hidden triplets
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        # with open(local_path + '/annotation.txt', 'w') as fo:
        #     for (h, r, t), s in zip(hidden_triples, scores):
        #         fo.write('{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], s))

        # Annotate hidden triplets
        print('annotation')
        
        # Per-user candidate item ids, "<uid> <item> <item> ..." per line.
        cand = {}
        with gzip.open('/common/users/yz956/kg/code/KBRD/data/cpa/cpa/kg_test_candidates_Cell_Phones_and_Accessories.txt.gz', 'rt') as f:
            for line in f:
                cells = line.split()
                uid = int(cells[0])
                item_ids = [int(i) for i in cells[1:]]
                cand[uid] = item_ids
        ann, train = [], []
        d = {}
        with open('/common/users/yz956/kg/code/KBRD/data/cpa/cpa/sample_pre.txt') as ft:
            for line in ft:
                line = line.strip().split('\t')
                train.append(line[1:])
        # Score every candidate (u, relation 0, item) per user and keep them
        # sorted by score, descending. 61254 is presumably the user count for
        # this dataset — TODO confirm.
        for u in range(61254):
            hiddens = []
            for i in cand[u]:
            # for i in range(61254, 108858):
                hiddens.append((u, 0, i))
            scores = kge_model.infer_step(kge_model, hiddens, args)
            score_np = np.array(scores)
            d = dict(zip(cand[u], scores))
            # d = dict(zip(range(61254, 108858), scores))
            d = sorted(d.items(), key=lambda x: x[1], reverse=True)
            
            # d_50 = d[:50]
            # for idx, t in enumerate(train[u]):
            #     for (tt, prob) in d_50:
            #         if int(t) == tt:
            #             d_50.remove((tt, prob))
            #             d_50.append(d[50 + idx])
            # assert len(d_50) == 50
            # d = {}

            d_50 = d
            ann.append(d_50)
        with open(local_path + '/annotation_1000_htr.txt', 'w') as fo:
            for idx, d in enumerate(ann):
                for (t, score) in d:
                    fo.write(str(idx) + '\t' + str(t) + '\t0\t' + str(score) + '\n')

        # with open(local_path + '/hidden_50_p.txt', 'w') as fo:
        #     for idx, d in enumerate(ann):
        #         for (t, score) in d:
        #             fo.write(str(idx) + '\t' + str(t) + '\t0\n')
        
        # Re-score and dump the hidden triplets as "<h>\t<t>\t<r>\t<score>".
        scores = kge_model.infer_step(kge_model, hidden_triples, args)
        with open(local_path + '/annotation_htr.txt', 'w') as fo:
            for (h, r, t), s in zip(hidden_triples, scores):
                # fo.write('{}\t{}\t{}\t{}\n'.format(id2entity[h], id2relation[r], id2entity[t], s))
                fo.write('{}\t{}\t{}\t{}\n'.format(str(h), str(t), str(r), s))
    
    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics, preds = kge_model.test_step(kge_model, train_triples, all_true_triples, args)
        log_metrics('Test', step, metrics)
Beispiel #16
0
def main():
    """Train RetinaFace on WIDER-style annotations.

    Builds train/val dataloaders, trains with Adam at lr=1e-3, logs losses to
    stdout (ASCII table) and TensorBoard every ``args.verbose`` batches,
    evaluates recall/precision every ``args.eval_step`` epochs, and saves a
    checkpoint every ``args.save_step`` epochs into ``args.save_path``.
    """
    args = get_args()
    # makedirs (not mkdir) so a missing parent directory does not crash,
    # and exist_ok=True avoids a race with a concurrent run.
    os.makedirs(args.save_path, exist_ok=True)
    log_path = os.path.join(args.save_path, 'log')
    os.makedirs(log_path, exist_ok=True)

    writer = SummaryWriter(log_dir=log_path)

    data_path = args.data_path
    train_path = os.path.join(data_path, 'train/label.txt')
    val_path = os.path.join(data_path, 'val/label.txt')
    dataset_train = TrainDataset(train_path,
                                 transform=transforms.Compose(
                                     [Resizer(), PadToSquare()]))
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=8,
                                  batch_size=args.batch,
                                  collate_fn=collater,
                                  shuffle=True)
    dataset_val = ValDataset(val_path,
                             transform=transforms.Compose(
                                 [Resizer(), PadToSquare()]))
    dataloader_val = DataLoader(dataset_val,
                                num_workers=8,
                                batch_size=args.batch,
                                collate_fn=collater)

    total_batch = len(dataloader_train)

    # Create torchvision-backbone RetinaFace; the FPN taps layer2/3/4.
    return_layers = {'layer2': 1, 'layer3': 2, 'layer4': 3}
    retinaface = torchvision_model.create_retinaface(return_layers)

    retinaface = retinaface.cuda()
    retinaface = torch.nn.DataParallel(retinaface).cuda()
    # .train() propagates the training flag to all submodules; the original
    # `retinaface.training = True` only set the attribute on the wrapper.
    retinaface.train()

    optimizer = optim.Adam(retinaface.parameters(), lr=1e-3)

    print('Start to train.')

    iteration = 0

    for epoch in range(args.epochs):
        retinaface.train()

        # Training
        for iter_num, data in enumerate(dataloader_train):
            optimizer.zero_grad()
            classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                [data['img'].cuda().float(), data['annot']])
            # Mean across DataParallel replicas.
            classification_loss = classification_loss.mean()
            bbox_regression_loss = bbox_regression_loss.mean()
            ldm_regression_loss = ldm_regression_loss.mean()

            # Equal-weighted multi-task loss (cls + bbox + landmarks).
            loss = classification_loss + bbox_regression_loss + ldm_regression_loss

            loss.backward()
            optimizer.step()

            if iter_num % args.verbose == 0:
                log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                    epoch, args.epochs, iter_num, total_batch)
                table_data = [['loss name', 'value'],
                              ['total_loss', str(loss.item())],
                              [
                                  'classification',
                                  str(classification_loss.item())
                              ], ['bbox',
                                  str(bbox_regression_loss.item())],
                              ['landmarks',
                               str(ldm_regression_loss.item())]]
                table = AsciiTable(table_data)
                log_str += table.table
                print(log_str)
                # write the log to tensorboard
                # NOTE(review): the x-axis is (logged-step count) * verbose,
                # approximating the global batch index across epochs only if
                # epoch length is a multiple of args.verbose — confirm intent.
                writer.add_scalars(
                    'losses:', {
                        'total_loss': loss.item(),
                        'cls_loss': classification_loss.item(),
                        'bbox_loss': bbox_regression_loss.item(),
                        'ldm_loss': ldm_regression_loss.item()
                    }, iteration * args.verbose)
                iteration += 1

        # Eval
        if epoch % args.eval_step == 0:
            print('-------- RetinaFace Pytorch --------')
            print('Evaluating epoch {}'.format(epoch))
            recall, precision = eval_widerface.evaluate(
                dataloader_val, retinaface)
            print('Recall:', recall)
            print('Precision:', precision)

        # Save model
        if (epoch + 1) % args.save_step == 0:
            torch.save(retinaface.state_dict(),
                       args.save_path + '/model_epoch_{}.pt'.format(epoch + 1))
Beispiel #17
0
def main(args):
    """Train/evaluate a KGE model (Euclidean, hyperbolic, or one-to-many
    Euclidean variants) on a standard entities.dict/relations.dict dataset.

    Loads id maps and train/valid/test triples, instantiates the model class
    selected by ``args.model``, optionally trains with a step-wise LR-decay
    schedule (divide by 10 at each warm-up boundary, boundary then moves 3x
    further out), logs metrics to TensorBoard, and runs final evaluations.
    """
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be choosed.')

    if args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be choosed.')

    if args.do_train and args.save_path is None:
        # create default save directory: $LOG_DIR/<dataset>/<model>/<timestamp>
        dt = datetime.datetime.now()
        args.save_path = os.path.join(
            os.environ['LOG_DIR'],
            args.data_path.split('/')[-1], args.model,
            datetime.datetime.now().strftime('%m%d%H%M%S'))

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)
    writer = SummaryWriter(log_dir=args.save_path)

    # Dict files are "<id>\t<name>" per line.
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:
                region = line.strip()
                regions.append(entity2id[region])
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('Save Path: {}'.format(args.save_path))
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # All true triples (used to filter false negatives during ranking).
    all_true_triples = train_triples + valid_triples + test_triples

    # Select the model class from the model name.
    if args.model in EUC_MODELS:
        ModelClass = EKGEModel
    elif args.model in HYP_MODELS:
        ModelClass = HKGEModel
    elif args.model in ONE_2_MANY_E_MODELS:
        ModelClass = O2MEKGEModel
    else:
        raise ValueError('model %s not supported' % args.model)

    if ModelClass != O2MEKGEModel:
        kge_model = ModelClass(
            model_name=args.model,
            nentity=nentity,
            nrelation=nrelation,
            hidden_dim=args.hidden_dim,
            gamma=args.gamma,
            p_norm=args.p_norm,
            dropout=args.dropout,
            entity_embedding_multiple=args.entity_embedding_multiple,
            relation_embedding_multiple=args.relation_embedding_multiple)
    else:
        # One-to-many variant additionally takes sibling count and rho.
        kge_model = ModelClass(
            model_name=args.model,
            nentity=nentity,
            nrelation=nrelation,
            hidden_dim=args.hidden_dim,
            gamma=args.gamma,
            p_norm=args.p_norm,
            dropout=args.dropout,
            entity_embedding_multiple=args.entity_embedding_multiple,
            relation_embedding_multiple=args.relation_embedding_multiple,
            nsiblings=args.nsib,
            rho=args.rho)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        # Alternates head-batch / tail-batch negative sampling.
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = init_optimizer(kge_model, current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model (and, when training, optimizer/schedule state)
        # from the checkpoint directory.
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Ramdomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    if args.do_train:
        logging.info('Start Training...')
        logging.info('init_step = %d' % init_step)
        logging.info('hidden_dim = %d' % args.hidden_dim)
        # %f, not %d: learning rates are fractional and %d logged them as 0.
        logging.info('learning_rate = %f' % current_learning_rate)
        logging.info('batch_size = %d' % args.batch_size)
        logging.info('negative_adversarial_sampling = %d' %
                     args.negative_adversarial_sampling)

        logging.info('gamma = %f' % args.gamma)
        logging.info('dropout = %f' % args.dropout)
        if args.negative_adversarial_sampling:
            logging.info('adversarial_temperature = %f' %
                         args.adversarial_temperature)

        # Set valid dataloader as it would be evaluated during training
        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):

            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)
            training_logs.append(log)
            write_metrics(writer, step, log, split='train')
            write_metrics(writer, step,
                          {'current_learning_rate': current_learning_rate})

            # LR decay: /10 at each warm-up boundary, next boundary 3x out.
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = init_optimizer(kge_model, current_learning_rate)

                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                # Average each metric over the window since the last log.
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                write_metrics(writer, step, metrics, split='train')
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)
                write_metrics(writer, step, metrics, split='valid')

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }

        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        # Label fixed: these are training-set metrics, not 'Test'.
        log_metrics('Train', step, metrics)
Beispiel #18
0
def main(args):
    """Entry point: train and/or evaluate a KGE model on a triple dataset.

    Reads entity/relation dictionaries and train/valid/test triples from
    ``args.data_path``, optionally mixes in pre-generated fake (noisy)
    triples, builds the KGEModel plus an optional confidence-aware trainer
    (CLF / LT / NoiGAN), runs the training loop with periodic checkpointing,
    logging and validation, and finishes with the requested evaluations.

    Raises:
        ValueError: if no mode is selected or a required path is missing.

    Fixes vs. previous revision: removed duplicated
    ``negative_adversarial_sampling`` log line, corrected the "Ramdomly"
    typo, corrected the evaluate_train metric label ('Test' -> 'Train'),
    and fixed "choosed" grammar in error messages.
    """
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be chosen.')

    if not args.do_train and args.init_checkpoint:
        override_config(args)
    elif args.data_path is None:
        raise ValueError('one of init_checkpoint/data_path must be chosen.')

    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')

    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)

    # Dictionary files are TSV lines of "<id>\t<name>".
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split('\t')
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)

    train_triples = read_triple(os.path.join(args.data_path, "train.txt"),
                                entity2id, relation2id)
    if args.self_test:
        # Drop the first 20% of training triples for the self-test regime.
        train_triples = train_triples[len(train_triples) // 5:]
    if args.fake:
        # Inject pre-generated noisy triples into the training set.
        fake_triples = pickle.load(
            open(os.path.join(args.data_path, "fake%s.pkl" % args.fake), "rb"))
        fake = torch.LongTensor(fake_triples)
        train_triples += fake_triples
    else:
        # Placeholder so downstream code always has a non-empty tensor.
        fake_triples = [(0, 0, 0)]
        fake = torch.LongTensor(fake_triples)
    if args.cuda:
        fake = fake.cuda()
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # Used during evaluation to filter corrupted triples that are true.
    all_true_triples = train_triples + valid_triples + test_triples

    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding)

    # Optional confidence-aware trainer that down-weights noisy triples.
    trainer = None
    if args.method == "CLF":
        trainer = ClassifierTrainer(train_triples, fake_triples, args,
                                    kge_model, args.hard)
    elif args.method == "LT":
        trainer = LTTrainer(train_triples, fake_triples, args, kge_model)
    elif args.method == "NoiGAN":
        trainer = NoiGANTrainer(train_triples, fake_triples, args, kge_model,
                                args.hard)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        # Alternates head-batch / tail-batch negative sampling.
        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                            kge_model.parameters()),
                                     lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        # Step counter is deliberately reset instead of resumed.
        init_step = 0  #checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            # current_learning_rate = checkpoint['current_learning_rate']
            # warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    # Logged once (previous revision emitted this line twice).
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        logging.info('learning_rate = %f' % current_learning_rate)

        training_logs = []

        #Training Loop
        triple2confidence_weights = None
        for step in range(init_step, args.max_steps):
            # Periodically (re)train the noise classifier and refresh the
            # per-triple confidence weights.
            if args.method == "CLF" and step % args.classify_steps == 0:
                logging.info('Train Classifier')
                metrics = trainer.train_classifier(trainer)
                log_metrics('Classifier', step, metrics)
                metrics = trainer.test_ave_score(trainer)
                log_metrics('Classifier', step, metrics)
                trainer.cal_confidence_weight()
            elif args.method == "NoiGAN" and step % args.classify_steps == 0:
                logging.info('Train NoiGAN')
                trainer.train_NoiGAN(trainer)
                metrics = trainer.test_ave_score(trainer)
                log_metrics('Classifier', step, metrics)
                trainer.cal_confidence_weight()

            log = kge_model.train_step(kge_model,
                                       optimizer,
                                       train_iterator,
                                       args,
                                       trainer=trainer)

            training_logs.append(log)

            # 10x learning-rate decay at warm_up_steps, then triple the
            # horizon so subsequent decays become progressively rarer.
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                                    kge_model.parameters()),
                                             lr=current_learning_rate)
                warm_up_steps = warm_up_steps * 3

            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args,
                           trainer)

            if step % args.log_steps == 0:
                # Average each metric over the window since the last log.
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)

        # Final checkpoint after the loop completes.
        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args, trainer)

    if trainer is not None:
        logging.info("Evaluating Classifier on Train Dataset")
        metrics = trainer.test_ave_score(trainer)
        log_metrics('Train', step, metrics)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)
        # logging.info("\t".join([metric for metric in metrics.values()]))

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        # Fixed label: these are training-set metrics, not test metrics.
        log_metrics('Train', step, metrics)
Beispiel #19
0
def main():
    """Train a KGE model with the PaddlePaddle backend.

    Builds the training dataset and head/tail samplers, optionally the
    eval dataset and validation samplers, constructs a BaseKEModel, and
    runs a step-based training loop with periodic logging, validation
    and model saving.
    """
    args = ArgParser().parse_args()
    prepare_save_path(args)
    # Evaluation ranks each positive against a fixed 1000 sampled negatives.
    args.neg_sample_size_eval = 1000
    set_global_seed(args.seed)

    init_time_start = time.time()
    dataset = get_dataset(args, args.data_path, args.dataset, args.format,
                          args.delimiter, args.data_files,
                          args.has_edge_importance)
    # Adjust batch sizes so they are compatible with the negative sample
    # sizes (chunked negative sampling requires divisibility).
    args.batch_size = get_compatible_batch_size(args.batch_size,
                                                args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    #print(args)
    set_logger(args)

    print("To build training dataset")
    t1 = time.time()
    train_data = TrainDataset(dataset,
                              args,
                              has_importance=args.has_edge_importance)
    print("Training dataset built, it takes %s seconds" % (time.time() - t1))
    args.num_workers = 8  # fix num_worker to 8
    print("Building training sampler")
    t1 = time.time()
    # One sampler corrupts heads, the other tails; the bidirectional
    # iterator below alternates between them each step.
    train_sampler_head = train_data.create_sampler(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        neg_sample_size=args.neg_sample_size,
        neg_mode='head')
    train_sampler_tail = train_data.create_sampler(
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        neg_sample_size=args.neg_sample_size,
        neg_mode='tail')
    train_sampler = NewBidirectionalOneShotIterator(train_sampler_head,
                                                    train_sampler_tail)
    print("Training sampler created, it takes %s seconds" % (time.time() - t1))

    if args.valid or args.test:
        # Cap the number of test processes by the number of GPUs available.
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(
                args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        print("To create eval_dataset")
        t1 = time.time()
        eval_dataset = EvalDataset(dataset, args)
        print("eval_dataset created, it takes %d seconds" % (time.time() - t1))

    if args.valid:
        # Validation uses tail-corruption samplers only, one per process.
        if args.num_proc > 1:
            valid_samplers = []
            for i in range(args.num_proc):
                print("creating valid sampler for proc %d" % i)
                t1 = time.time()
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    mode='tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_samplers.append(valid_sampler_tail)
                print(
                    "Valid sampler for proc %d created, it takes %s seconds" %
                    (i, time.time() - t1))
        else:
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                mode='tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            valid_samplers = [valid_sampler_tail]

    # Dump the full run configuration to the log for reproducibility.
    for arg in vars(args):
        logging.info('{:20}:{}'.format(arg, getattr(args, arg)))

    print("To create model")
    t1 = time.time()
    model = BaseKEModel(args=args,
                        n_entities=dataset.n_entities,
                        n_relations=dataset.n_relations,
                        model_name=args.model_name,
                        hidden_size=args.hidden_dim,
                        entity_feat_dim=dataset.entity_feat.shape[1],
                        relation_feat_dim=dataset.relation_feat.shape[1],
                        gamma=args.gamma,
                        double_entity_emb=args.double_ent,
                        relation_times=args.ote_size,
                        scale_type=args.scale_type)

    model.entity_feat = dataset.entity_feat
    model.relation_feat = dataset.relation_feat
    print(len(model.parameters()))

    if args.cpu_emb:
        print("using cpu emb\n" * 5)
    else:
        print("using gpu emb\n" * 5)
    # Adam covers the dense (MLP) parameters at args.mlp_lr; CPU-held
    # embeddings are updated manually with lr_tensor inside the loop.
    optimizer = paddle.optimizer.Adam(learning_rate=args.mlp_lr,
                                      parameters=model.parameters())
    lr_tensor = paddle.to_tensor(args.lr)

    global_step = 0
    tic_train = time.time()
    log = {}
    log["loss"] = 0.0
    log["regularization"] = 0.0
    for step in range(0, args.max_step):
        pos_triples, neg_triples, ids, neg_head = next(train_sampler)
        loss = model.forward(pos_triples, neg_triples, ids, neg_head)

        # NOTE(review): log["loss"] is overwritten every step, so the value
        # printed at log_interval is the *last* step's loss, not an average,
        # and the reset to 0.0 below is effectively a no-op — confirm intent.
        log["loss"] = loss.numpy()[0]
        if args.regularization_coef > 0.0 and args.regularization_norm > 0:
            coef, nm = args.regularization_coef, args.regularization_norm
            reg = coef * norm(model.entity_embedding.curr_emb, nm)
            log['regularization'] = reg.numpy()[0]
            loss = loss + reg

        loss.backward()
        optimizer.step()
        if args.cpu_emb:
            # Manual sparse update of the CPU-side entity embedding table.
            model.entity_embedding.step(lr_tensor)
        optimizer.clear_grad()
        if (step + 1) % args.log_interval == 0:
            speed = args.log_interval / (time.time() - tic_train)
            logging.info(
                "step: %d, train loss: %.5f, regularization: %.4e, speed: %.2f steps/s"
                % (step, log["loss"], log["regularization"], speed))
            log["loss"] = 0.0
            tic_train = time.time()

        # Periodic validation; valid_samplers exists whenever args.valid is
        # set, so the short-circuit keeps this safe.
        if args.valid and (
                step + 1
        ) % args.eval_interval == 0 and step > 1 and valid_samplers is not None:
            print("Valid begin")
            valid_start = time.time()
            valid_input_dict = test(args,
                                    model,
                                    valid_samplers,
                                    step,
                                    rank=0,
                                    mode='Valid')
            paddle.save(
                valid_input_dict,
                os.path.join(args.save_path, "valid_{}.pkl".format(step)))
            # Save the model for the inference
        if (step + 1) % args.save_step == 0:
            print("The step:{}, save model path:{}".format(
                step + 1, args.save_path))
            model.save_model()
            print("Save model done.")
Beispiel #20
0
# Attribute-triple training setup for the DBP15k fr-en alignment dataset.
# NOTE(review): entity_list and attr_list are presumably loaded by earlier
# (out-of-view) read_ids calls — confirm against the full script.
value_list = read_ids("data/fr_en/att_value2id_all")

entity_count = len(entity_list)
attr_count = len(attr_list)
value_count = len(value_list)
print("entity:", entity_count, "attr:", attr_count, "value:", value_count)
device = "cuda"
learning_rate = 0.001
tensorboard_log_dir = "./result/log/"
checkpoint_path = "./result/fr_en/checkpoint.tar"
train_set = DBP15kDataset('data/fr_en/att_triple_all', entity_list, value_list)
train_generator = data.DataLoader(train_set, batch_size=512)

# Head- and tail-corruption loaders over the same attribute triples:
# 256 negatives per positive, batches of 1024, no shuffling.
train_triples = read_triple("data/fr_en/att_triple_all")
train_dataloader_head = data.DataLoader(
    TrainDataset(train_triples, entity_count, attr_count, value_count, 256, 'head-batch'),
    batch_size=1024,
    shuffle=False,
    num_workers=4,
    collate_fn=TrainDataset.collate_fn
)
train_dataloader_tail = data.DataLoader(
    TrainDataset(train_triples, entity_count, attr_count, value_count, 256, 'tail-batch'),
    batch_size=1024,
    shuffle=False,
    num_workers=4,
    collate_fn=TrainDataset.collate_fn
)
# Alternates head-batch / tail-batch samples each iteration.
train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

model = TransE(entity_count, attr_count, value_count, device).to(device)
    model = AdaptiveFactorizationNetwork(field_dims=field_dims,
                                         embed_dim=args.embed_dims,
                                         LNN_dim=args.LNN_dim,
                                         mlp_dims=(400, 400, 400),
                                         dropouts=(0, 0, 0))

    #criterion = torch.nn.SmoothL1Loss()
    criterion = torch.nn.L1Loss()
    optimizer = torch.optim.Adam(params=model.parameters(),
                                 lr=args.lr,
                                 weight_decay=1e-6)

    all_data = pd.read_csv(args.input_csv)

    train_data = all_data[all_data.flag < 8]
    val_data = all_data[all_data.flag >= 8]

    train_set = TrainDataset(train_data)
    val_set = TrainDataset(val_data)

    train_loader = DataLoader(train_set, batch_size=args.batch_size)
    val_loader = DataLoader(val_set, batch_size=args.batch_size)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for _ in range(args.epochs):
        train_one_epoch(model, train_loader, criterion, device, optimizer)
        validate_model(model, val_loader, device)

        torch.save(model.state_dict(), 'model.pth')
Beispiel #22
0
def run(args, logger):
    """Train and evaluate a KGE model, optionally across multiple processes.

    Builds train/valid/test samplers (one head/tail pair per worker when
    args.num_proc > 1), loads the model, runs training (single- or
    multi-process), optionally saves embeddings plus a config.json, and
    finally evaluates on the test split.
    """
    init_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)

    # A negative eval sample size < 0 means "rank against all entities".
    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = dataset.n_entities
    args.batch_size = get_compatible_batch_size(args.batch_size,
                                                args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross partition relaiton, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part
                                                 == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part
    args.num_workers = 8  # fix num_worker to 8

    if args.num_proc > 1:
        # One head/tail sampler pair per training process, partitioned by rank.
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_sample_size,
                mode='head',
                num_workers=args.num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_sample_size,
                mode='tail',
                num_workers=args.num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_sample_size,
                                                args.neg_sample_size, True,
                                                dataset.n_entities))

        # NOTE(review): this iterator reuses the *last* rank's samplers and
        # appears unused on the multi-process path (train_mp receives
        # train_samplers[i]) — confirm before removing.
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_sample_size,
            args.neg_sample_size, True, dataset.n_entities)
    else:  # This is used for debug
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='head',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='tail',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_sample_size,
            args.neg_sample_size, True, dataset.n_entities)

    if args.valid or args.test:
        # Cap the number of test processes by the number of GPUs available.
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(
                args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)

    if args.valid:
        # Chunked head/tail corruption samplers for validation, one pair
        # per process (or a single pair in the debug path).
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:  # This is used for debug
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Same structure as validation, but over the test split and using
        # num_test_proc (set above whenever args.test is true).
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_eval,
                    args.neg_sample_size_eval,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=args.num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-head',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_eval,
                args.neg_sample_size_eval,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=args.num_workers,
                rank=0,
                ranks=1)

    # load model
    model = load_model(logger, args, dataset.n_entities, dataset.n_relations)
    if args.num_proc > 1 or args.async_update:
        # Embeddings must live in shared memory for worker processes.
        model.share_memory()

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None

    print('Total initialize time {:.3f} seconds'.format(time.time() -
                                                        init_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        # Spawn one training process per rank; the barrier synchronizes them.
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]
                             ] if args.valid else None
            proc = mp.Process(target=train_mp,
                              args=(args, model, train_samplers[i],
                                    valid_sampler, i, rel_parts, cross_rels,
                                    barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)

    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump(
                {
                    'dataset': args.dataset,
                    'model': args.model_name,
                    'emb_size': args.hidden_dim,
                    'max_train_step': args.max_step,
                    'batch_size': args.batch_size,
                    'neg_sample_size': args.neg_sample_size,
                    'lr': args.lr,
                    'gamma': args.gamma,
                    'double_ent': args.double_ent,
                    'double_rel': args.double_rel,
                    'neg_adversarial_sampling': args.neg_adversarial_sampling,
                    'adversarial_temperature': args.adversarial_temperature,
                    'regularization_coef': args.regularization_coef,
                    'regularization_norm': args.regularization_norm
                },
                outfile,
                indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            # Workers push per-triple logs through the queue; the parent
            # averages each metric across all collected logs.
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                print('Test average {} : {}'.format(k, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('testing takes {:.3f} seconds'.format(time.time() - start))
Beispiel #23
0
def run(args, logger):
    """Single-machine train/valid/test driver for a KGE model.

    Loads the dataset, builds (per-process) train/valid/test samplers,
    trains the model — multi-process when ``args.num_proc > 1`` — optionally
    saves embeddings plus a ``config.json``, and finally evaluates on the
    test set.

    NOTE(review): mutates several fields of ``args`` in place
    (neg_sample_size_test, eval_filter, neg_chunk_size*, strict_rel_part,
    soft_rel_part, num_thread, num_test_proc).
    """
    train_time_start = time.time()
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format,
                          args.data_files)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    # a negative value means "use every entity as a negative sample at test time"
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    num_workers = args.num_worker
    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    # if there is no cross partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part
                                                 == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part

    # Automatically set number of OMP threads for each process if it is not provided
    # The value for GPU is evaluated in AWS p3.16xlarge
    # The value for CPU is evaluated in AWS x1.32xlarge
    if args.nomp_thread_per_process == -1:
        if len(args.gpu) > 0:
            # GPU training
            args.num_thread = 4
        else:
            # CPU training
            args.num_thread = 1
    else:
        args.num_thread = args.nomp_thread_per_process

    # One bidirectional (head + tail corruption) training iterator per process.
    if args.num_proc > 1:
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='head',
                num_workers=num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='tail',
                num_workers=num_workers,
                shuffle=True,
                exclude_positive=False,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_chunk_size,
                                                args.neg_sample_size, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='head',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler_tail = train_data.create_sampler(args.batch_size,
                                                       args.neg_sample_size,
                                                       args.neg_chunk_size,
                                                       mode='tail',
                                                       num_workers=num_workers,
                                                       shuffle=True,
                                                       exclude_positive=False)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_chunk_size,
            args.neg_sample_size, True, n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        # cap the number of test processes at the number of GPUs
        if len(args.gpu) > 1:
            args.num_test_proc = args.num_proc if args.num_proc < len(
                args.gpu) else len(args.gpu)
        else:
            args.num_test_proc = args.num_proc
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regualr negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regualr negative sampler because we need to ensure that
        # all positive edges are excluded.
        # We use a maximum of num_gpu in test stage to save GPU memory.
        if args.num_test_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_test_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_test_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    # embeddings must live in shared memory for worker processes / async updates
    if args.num_proc > 1 or args.async_update:
        model.share_memory()

    print('Total data loading time {:.3f} seconds'.format(time.time() -
                                                          train_time_start))

    # train
    start = time.time()
    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None
    if args.num_proc > 1:
        procs = []
        barrier = mp.Barrier(args.num_proc)
        for i in range(args.num_proc):
            valid_sampler = [valid_sampler_heads[i], valid_sampler_tails[i]
                             ] if args.valid else None
            proc = mp.Process(target=train_mp,
                              args=(args, model, train_samplers[i],
                                    valid_sampler, i, rel_parts, cross_rels,
                                    barrier))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers, rel_parts=rel_parts)
    print('training takes {} seconds'.format(time.time() - start))

    # persist embeddings and the model configuration used to produce them
    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

        # We need to save the model configurations as well.
        conf_file = os.path.join(args.save_emb, 'config.json')
        with open(conf_file, 'w') as outfile:
            json.dump(
                {
                    'dataset': args.dataset,
                    'model': args.model_name,
                    'emb_size': args.hidden_dim,
                    'max_train_step': args.max_step,
                    'batch_size': args.batch_size,
                    'neg_sample_size': args.neg_sample_size,
                    'lr': args.lr,
                    'gamma': args.gamma,
                    'double_ent': args.double_ent,
                    'double_rel': args.double_rel,
                    'neg_adversarial_sampling': args.neg_adversarial_sampling,
                    'adversarial_temperature': args.adversarial_temperature,
                    'regularization_coef': args.regularization_coef,
                    'regularization_norm': args.regularization_norm
                },
                outfile,
                indent=4)

    # test
    if args.test:
        start = time.time()
        if args.num_test_proc > 1:
            # fan out one test process per sampler pair; collect logs via a queue
            queue = mp.Queue(args.num_test_proc)
            procs = []
            for i in range(args.num_test_proc):
                proc = mp.Process(target=test_mp,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            total_metrics = {}  # NOTE(review): unused
            metrics = {}
            logs = []
            for i in range(args.num_test_proc):
                log = queue.get()
                logs = logs + log

            # average each metric over all collected per-edge logs
            for metric in logs[0].keys():
                metrics[metric] = sum([log[metric]
                                       for log in logs]) / len(logs)
            for k, v in metrics.items():
                # NOTE(review): args.step may not be set on this path — verify
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('test:', time.time() - start)
Beispiel #24
0
def start_worker(args, logger):
    """Start kvclient processes for distributed training.

    Reads the kvserver IP config, loads this machine's partition of the
    dataset, builds one bidirectional train sampler per client, and spawns
    ``args.num_client`` processes running ``dist_train_test`` against the
    shared-memory partition books.

    NOTE(review): mutates ``args`` in place (machine_id, strict_rel_part,
    soft_rel_part, neg_sample_size_eval, batch_size, batch_size_eval,
    num_workers).
    """
    init_time_start = time.time()
    time.sleep(WAIT_TIME)  # wait for launch script
    # map of kvserver addresses, keyed by server id
    server_namebook = dgl.contrib.read_ip_config(filename=args.ip_config)

    args.machine_id = get_local_machine_id(server_namebook)

    # this machine's slice of the graph plus its local->global entity id map
    dataset, entity_partition_book, local2global = get_partition_dataset(
        args.data_path, args.dataset, args.format, args.machine_id)

    n_entities = dataset.n_entities
    n_relations = dataset.n_relations

    print('Partition %d n_entities: %d' % (args.machine_id, n_entities))
    print("Partition %d n_relations: %d" % (args.machine_id, n_relations))

    entity_partition_book = F.tensor(entity_partition_book)
    relation_partition_book = get_long_tail_partition(dataset.n_relations,
                                                      args.total_machine)
    relation_partition_book = F.tensor(relation_partition_book)
    local2global = F.tensor(local2global)

    # share the partition books across the client processes spawned below
    relation_partition_book.share_memory_()
    entity_partition_book.share_memory_()
    local2global.share_memory_()

    train_data = TrainDataset(dataset, args, ranks=args.num_client)
    # if there is no cross partition relation, we fall back to strict_rel_part
    args.strict_rel_part = args.mix_cpu_gpu and (train_data.cross_part
                                                 == False)
    args.soft_rel_part = args.mix_cpu_gpu and args.soft_rel_part and train_data.cross_part

    # a negative value means "use every entity as a negative sample at eval time"
    if args.neg_sample_size_eval < 0:
        args.neg_sample_size_eval = dataset.n_entities
    args.batch_size = get_compatible_batch_size(args.batch_size,
                                                args.neg_sample_size)
    args.batch_size_eval = get_compatible_batch_size(args.batch_size_eval,
                                                     args.neg_sample_size_eval)

    args.num_workers = 8  # fix num_workers to 8
    # one bidirectional (head + tail corruption) iterator per client
    train_samplers = []
    for i in range(args.num_client):
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='head',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False,
            rank=i)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_sample_size,
            mode='tail',
            num_workers=args.num_workers,
            shuffle=True,
            exclude_positive=False,
            rank=i)
        train_samplers.append(
            NewBidirectionalOneShotIterator(train_sampler_head,
                                            train_sampler_tail,
                                            args.neg_sample_size,
                                            args.neg_sample_size, True,
                                            n_entities))

    # free memory referenced by the raw dataset before forking workers
    dataset = None

    model = load_model(logger, args, n_entities, n_relations)
    model.share_memory()

    print('Total initialize time {:.3f} seconds'.format(time.time() -
                                                        init_time_start))

    rel_parts = train_data.rel_parts if args.strict_rel_part or args.soft_rel_part else None
    cross_rels = train_data.cross_rels if args.soft_rel_part else None

    # spawn one training process per client; a barrier keeps them in step
    procs = []
    barrier = mp.Barrier(args.num_client)
    for i in range(args.num_client):
        proc = mp.Process(target=dist_train_test,
                          args=(args, model, train_samplers[i],
                                entity_partition_book, relation_partition_book,
                                local2global, i, rel_parts, cross_rels,
                                barrier))
        procs.append(proc)
        proc.start()
    for proc in procs:
        proc.join()
Beispiel #25
0
    bs = args.bs
    lr = args.lr
    width = args.width
    depth = args.depth
    epochs = args.epochs
    verbose = args.verbose

    # set random seed
    set_random_seed(seed)
    # define the device for training
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # get training ids
    train_ids, valid_ids, test_ids = get_ids(65)
    # define dataloaders
    if args.dynamic:
        train_data = TrainDataset(train_ids)
        train_iter = DataLoader(train_data,
                                batch_size=bs,
                                num_workers=6,
                                sampler=LoopSampler)
    else:
        train_data = StaticTrainDataset(train_ids)
        train_iter = DataLoader(train_data,
                                batch_size=bs,
                                num_workers=6,
                                shuffle=True)

    train_tdata = TestDataset(train_ids)
    valid_tdata = TestDataset(valid_ids)
    test_tdata = TestDataset(test_ids)
Beispiel #26
0
    def __init__(self, args, writer):
        """Train or evaluate an image-retrieval classifier.

        Depending on ``args.mode``:
          * ``'train'`` — holdout-split the train CSV with
            StratifiedShuffleSplit, train a backbone classifier, and save
            the best weights.
          * ``'test'``  — load saved weights, extract feature vectors for
            query and reference images, and print ranked retrieval results
            (mAP-style ranking).

        Args:
            args: parsed options (dataset_path, mode, backbone, image_size,
                batch_size, class_num, epochs, test_model_savepath, ...).
            writer: object for saving current status.

        Raises:
            NotImplementedError: when ``args.mode`` is neither 'train' nor 'test'.
        """
        self.args = args
        self.writer = writer  # object for saving current status
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.df_train = pd.read_csv(
            os.path.join(args.dataset_path, 'train.csv'))
        self.df_reference = pd.read_csv(
            os.path.join(args.dataset_path, 'reference.csv'))
        self.df_query = pd.read_csv(
            os.path.join(args.dataset_path, 'query.csv'))

        if args.mode == 'train':
            # In train mode we train a plain classifier (with the softmax layer removed).
            print('Training Start...')
            print('Total epoches:', args.epochs)
            # single stratified holdout split: 90% train / 10% validation
            stratified_folds = StratifiedShuffleSplit(n_splits=1,
                                                      test_size=0.1)

            for (train_index, validation_index) in stratified_folds.split(
                    self.df_train, self.df_train['class']):
                df_train = self.df_train.iloc[train_index, :].reset_index()
                df_validation = self.df_train.iloc[
                    validation_index, :].reset_index()

                # Load custom transforms
                self.transform = custom_transforms(model_name=args.backbone,
                                                   target_size=args.image_size)

                # Load dataset for train
                train_dataset = TrainDataset(
                    os.path.join(args.dataset_path, 'Images'),
                    df_train,
                    transforms=self.transform['train'])
                train_loader = DataLoader(train_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=True)

                # Load dataset for validation
                validation_dataset = TrainDataset(
                    os.path.join(args.dataset_path, 'Images'),
                    df_validation,
                    transforms=self.transform['validation'])
                validation_loader = DataLoader(validation_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)

                print('Dataset class : ', self.args.class_num)
                print(
                    "Holdout Train dataset length: {} / Holdout Validation dataset length: {}"
                    .format(str(len(train_dataset)),
                            str(len(validation_dataset))))

                # Define model
                model, evaluator, lr_scheduler, optimizer = self.define_network(
                    args=args, length_train_dataloader=len(train_dataset))
                model.to(self.device)

                # Define Criterion
                criterion = nn.CrossEntropyLoss(
                )  # buildLosses(cuda=args.cuda).build_loss(mode=args.loss_type)
                print('optimizer : ', type(optimizer), ' lr scheduler : ',
                      type(lr_scheduler))

                # Train the model
                self.train_model(args=args,
                                 model=model,
                                 optimizer=optimizer,
                                 scheduler=lr_scheduler,
                                 criterion=criterion,
                                 train_loader=train_loader,
                                 df_validation=df_validation,
                                 validation_dataset=validation_dataset,
                                 validation_loader=validation_loader,
                                 weight_file_name='weight_best.pth')
                # free the fold's objects before the next iteration
                del (model, evaluator, lr_scheduler, optimizer)
                gc.collect()

        elif args.mode == 'test':  # test mode: run model inference and compute scores
            '''
            After training, the model takes a query image as input, compares it
            against each reference image by similarity, and lists reference image
            filenames sorted by their similarity to the query.
            - Reference images are the comparison targets for the query image.
            - Query images are fed into the model to perform image retrieval.
            '''

            print('Inferring Start...')
            query_path = args.dataset_path + '/QueryImages'
            reference_path = args.dataset_path + '/Images/'
            model_path = args.test_model_savepath
            all_dir = glob.glob(model_path + '*', recursive=True)

            # one weight file per saved-model directory
            weight_list = []
            for _dir in all_dir:
                weight_list.append(os.path.join(_dir, os.listdir(_dir)[0]))

            db = [
                os.path.join(reference_path, path)
                for path in os.listdir(reference_path)
            ]  # collect reference image paths under 'reference_path' into db
            queries = [
                v.split('/')[-1].split('.')[0] for v in os.listdir(query_path)
            ]  # keep only the extension-less basename of each file under 'query_path' (e.g. 'yooa1', 'yeji3', ...)
            db = [v.split('/')[-1].split('.')[0]
                  for v in db]  # keep only the extension-less basename of each db file
            queries.sort()
            db.sort()

            transform = custom_transforms(model_name=args.backbone,
                                          target_size=args.image_size)
            ref_dataset = TestDataset(reference_path,
                                      self.df_reference,
                                      transforms=transform['test'])
            ref_loader = DataLoader(ref_dataset,
                                    batch_size=args.test_batch_size,
                                    shuffle=False,
                                    num_workers=4,
                                    pin_memory=True)
            query_dataset = TestDataset(query_path,
                                        self.df_query,
                                        transforms=transform['test'])
            query_loader = DataLoader(query_dataset,
                                      batch_size=args.test_batch_size,
                                      shuffle=False,
                                      num_workers=4,
                                      pin_memory=True)

            model = initialize_model(args=None,
                                     model_name=args.backbone,
                                     num_classes=args.class_num)
            model.load_state_dict(torch.load(weight_list[0]), strict=False)
            print('model loaded:', weight_list[0])
            model.eval()
            # BUG FIX: was `model.to(device)` — `device` is undefined in this
            # scope; the device is stored on `self.device` (set above).
            model.to(self.device)

            # Extract feature vectors for the reference and query images from the trained model.
            reference_paths, reference_vecs = self.batch_process(
                model, ref_loader)
            query_paths, query_vecs = self.batch_process(model, query_loader)
            assert query_paths == queries and reference_paths == db, "order of paths should be same"

            # Compute the similarity between the two vector sets (query_vecs, reference_vecs).
            sim_matrix = self.calculate_sim_matrix(query_vecs, reference_vecs)

            # rank references per query, most similar first
            indices = np.argsort(sim_matrix, axis=1)
            indices = np.flip(indices, axis=1)

            retrieval_results = {}
            # Evaluation: mean average precision (mAP)
            # You can change this part to fit your evaluation skim
            for (i, query) in enumerate(query_paths):
                query = query.split('/')[-1].split('.')[0]
                ranked_list = [
                    reference_paths[k].split('/')[-1].split('.')[0]
                    for k in indices[i]
                ]
                ranked_list = ranked_list[:1000]

                retrieval_results[query] = ranked_list

            print('Retrieval done.')
            print(retrieval_results)
        else:
            print("wrong mode input.")
            raise NotImplementedError
Beispiel #27
0
def main(args):
    """Train / validate / test a knowledge-graph-embedding model.

    Validates the argument combination, loads entity/relation dictionaries
    and triples from ``args.data_path``, builds a ``KGEModel``, optionally
    restores a checkpoint, runs the training loop with warm-up learning-rate
    decay, and evaluates on the requested splits.

    Args:
        args: parsed command-line options (do_train/do_valid/do_test,
            data_path, save_path, init_checkpoint, model hyperparameters, ...).

    Raises:
        ValueError: when no mode is selected, when neither a checkpoint nor a
            data path is given, or when training without a save path.
    """
    # validate mode selection
    if (not args.do_train) and (not args.do_valid) and (not args.do_test):
        raise ValueError('one of train/val/test mode must be choosed.')
    # optionally override some command-line args from a saved config
    if args.init_checkpoint:  # init_checkpoint is a path; restore config from it
        override_config(args)  # override_config needs init_checkpoint, which this branch guarantees
    elif args.data_path is None:  # without a checkpoint, data_path must be given explicitly
        raise ValueError('one of init_checkpoint/data_path must be choosed.')
    # training requested but no save path given
    if args.do_train and args.save_path is None:
        raise ValueError('Where do you want to save your trained model?')
    # create the save directory if it does not exist yet
    if args.save_path and not os.path.exists(args.save_path):
        os.makedirs(args.save_path)

    # Write logs to checkpoint and console
    set_logger(args)
    # entities.dict / relations.dict: one "<id>\t<code>" pair per line;
    # entity2id / relation2id map each code to its integer id
    with open(os.path.join(args.data_path, 'entities.dict')) as fin:
        entity2id = dict()
        for line in fin:
            eid, entity = line.strip().split(
                '\t')  # e.g. "1\txx" -> eid='1', entity='xx'
            entity2id[entity] = int(eid)

    with open(os.path.join(args.data_path, 'relations.dict')) as fin:
        relation2id = dict()
        for line in fin:
            rid, relation = line.strip().split('\t')
            relation2id[relation] = int(rid)

    # Read regions for Countries S* datasets
    if args.countries:
        regions = list()
        # only the Countries datasets ship a regions.list file
        with open(os.path.join(args.data_path, 'regions.list')) as fin:
            for line in fin:  # only 5 lines
                region = line.strip()
                regions.append(entity2id[region])  # store region entity ids
        args.regions = regions

    nentity = len(entity2id)
    nrelation = len(relation2id)

    args.nentity = nentity
    args.nrelation = nrelation

    logging.info('Model: %s' % args.model)
    logging.info('Data Path: %s' % args.data_path)
    logging.info('#entity: %d' % nentity)
    logging.info('#relation: %d' % nrelation)
    # load train/valid/test triples and log their sizes
    train_triples = read_triple(os.path.join(args.data_path, 'train.txt'),
                                entity2id, relation2id)
    logging.info('#train: %d' % len(train_triples))
    valid_triples = read_triple(os.path.join(args.data_path, 'valid.txt'),
                                entity2id, relation2id)
    logging.info('#valid: %d' % len(valid_triples))
    test_triples = read_triple(os.path.join(args.data_path, 'test.txt'),
                               entity2id, relation2id)
    logging.info('#test: %d' % len(test_triples))

    # All true triples
    all_true_triples = train_triples + valid_triples + test_triples
    # build the model
    kge_model = KGEModel(
        model_name=args.model,
        nentity=nentity,
        nrelation=nrelation,
        hidden_dim=args.hidden_dim,
        gamma=args.gamma,
        double_entity_embedding=args.double_entity_embedding,
        double_relation_embedding=args.double_relation_embedding)

    logging.info('Model Parameter Configuration:')
    for name, param in kge_model.named_parameters():
        logging.info('Parameter %s: %s, require_grad = %s' %
                     (name, str(param.size()), str(param.requires_grad)))

    if args.cuda:
        kge_model = kge_model.cuda()

    if args.do_train:
        # Set training dataloader iterator
        train_dataloader_head = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'head-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_dataloader_tail = DataLoader(
            TrainDataset(train_triples, nentity, nrelation,
                         args.negative_sample_size, 'tail-batch'),
            batch_size=args.batch_size,
            shuffle=True,
            num_workers=max(1, args.cpu_num // 2),
            collate_fn=TrainDataset.collate_fn)

        train_iterator = BidirectionalOneShotIterator(train_dataloader_head,
                                                      train_dataloader_tail)

        # Set training configuration
        current_learning_rate = args.learning_rate
        optimizer = torch.optim.Adam(
            filter(lambda p: p.requires_grad,
                   kge_model.parameters()),  # only optimize params with requires_grad=True
            lr=current_learning_rate)
        if args.warm_up_steps:
            warm_up_steps = args.warm_up_steps
        else:
            warm_up_steps = args.max_steps // 2

    if args.init_checkpoint:
        # Restore model from checkpoint directory
        logging.info('Loading checkpoint %s...' % args.init_checkpoint)
        checkpoint = torch.load(
            os.path.join(args.init_checkpoint, 'checkpoint'))
        init_step = checkpoint['step']
        kge_model.load_state_dict(checkpoint['model_state_dict'])
        if args.do_train:
            current_learning_rate = checkpoint['current_learning_rate']
            warm_up_steps = checkpoint['warm_up_steps']
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    else:
        # BUG FIX: typo 'Ramdomly' -> 'Randomly' in the log message
        logging.info('Randomly Initializing %s Model...' % args.model)
        init_step = 0

    step = init_step

    logging.info('Start Training...')
    logging.info('init_step = %d' % init_step)
    if args.do_train:
        # BUG FIX: current_learning_rate is only defined on the do_train path
        # (unconditional logging raised NameError for eval-only runs), and it
        # is a float, so format with %f instead of %d.
        logging.info('learning_rate = %f' % current_learning_rate)
    logging.info('batch_size = %d' % args.batch_size)
    logging.info('hidden_dim = %d' % args.hidden_dim)
    logging.info('gamma = %f' % args.gamma)
    # BUG FIX: this value was logged twice, once mis-formatted with %d;
    # keep the single %s-formatted line.
    logging.info('negative_adversarial_sampling = %s' %
                 str(args.negative_adversarial_sampling))
    if args.negative_adversarial_sampling:
        logging.info('adversarial_temperature = %f' %
                     args.adversarial_temperature)

    # Set valid dataloader as it would be evaluated during training

    if args.do_train:
        training_logs = []

        # Training Loop
        for step in range(init_step, args.max_steps):
            log = kge_model.train_step(kge_model, optimizer, train_iterator,
                                       args)

            training_logs.append(log)
            # learning-rate schedule: once past warm_up_steps, divide lr by 10
            if step >= warm_up_steps:
                current_learning_rate = current_learning_rate / 10
                logging.info('Change learning_rate to %f at step %d' %
                             (current_learning_rate, step))
                # rebuild the optimizer with the new learning rate
                optimizer = torch.optim.Adam(
                    filter(lambda p: p.requires_grad, kge_model.parameters()),
                    lr=current_learning_rate
                )
                warm_up_steps = warm_up_steps * 3  # push the next decay point out
            # save a checkpoint every save_checkpoint_steps
            if step % args.save_checkpoint_steps == 0:
                save_variable_list = {
                    'step': step,
                    'current_learning_rate': current_learning_rate,
                    'warm_up_steps': warm_up_steps
                }
                save_model(kge_model, optimizer, save_variable_list, args)

            if step % args.log_steps == 0:
                # average each metric over the logs collected since the last report
                metrics = {}
                for metric in training_logs[0].keys():
                    metrics[metric] = sum(
                        [log[metric]
                         for log in training_logs]) / len(training_logs)
                log_metrics('Training average', step, metrics)
                training_logs = []

            if args.do_valid and step % args.valid_steps == 0:
                logging.info('Evaluating on Valid Dataset...')
                metrics = kge_model.test_step(kge_model, valid_triples,
                                              all_true_triples, args)
                log_metrics('Valid', step, metrics)

        save_variable_list = {
            'step': step,
            'current_learning_rate': current_learning_rate,
            'warm_up_steps': warm_up_steps
        }
        save_model(kge_model, optimizer, save_variable_list, args)

    if args.do_valid:
        logging.info('Evaluating on Valid Dataset...')
        metrics = kge_model.test_step(kge_model, valid_triples,
                                      all_true_triples, args)
        log_metrics('Valid', step, metrics)

    if args.do_test:
        logging.info('Evaluating on Test Dataset...')
        metrics = kge_model.test_step(kge_model, test_triples,
                                      all_true_triples, args)
        log_metrics('Test', step, metrics)

    if args.evaluate_train:
        logging.info('Evaluating on Training Dataset...')
        metrics = kge_model.test_step(kge_model, train_triples,
                                      all_true_triples, args)
        # BUG FIX: these are training-set metrics; they were mislabeled 'Test'.
        log_metrics('Train', step, metrics)
def main():
    """End-to-end RetinaFace training entry point.

    Builds train/val dataloaders from the ``retina-train-split*`` label
    files, creates a RetinaFace model on a torchvision ResNet backbone
    (optionally warm-started from ``args.model_path``), trains with Adam,
    periodically evaluates recall/precision on the validation set, and
    checkpoints both the best-precision model and periodic snapshots under
    ``args.save_path``. Loss and metric curves go to TensorBoard.
    """
    # Best validation precision observed so far; gates saving of the
    # "model_Best_epoch_*.pt" checkpoint below.
    precision_global = 0
    args = get_args()
    if not os.path.exists(args.save_path):
        os.mkdir(args.save_path)
    log_path = os.path.join(args.save_path, 'log')
    if not os.path.exists(log_path):
        os.mkdir(log_path)

    # TensorBoard writer; all scalars below are logged through it.
    writer = SummaryWriter(log_dir=log_path)

    data_path = args.data_path
    train_path = os.path.join(
        data_path,
        'retina-train-splitTrain.txt')  #"train\\label.txt")#'train.txt')
    val_path = os.path.join(
        data_path, "retina-train-splitTest.txt"
    )  #"retina-train-splitTest.txt") #'retina-val.txt')##'val.txt')
    # train_path = os.path.join(data_path,'train\\label.txt')#"train\\label.txt")#'train.txt')
    # val_path = os.path.join(data_path,'val\\label.txt')#"val\\label.txt")#'val.txt')
    # dataset_train = TrainDataset(train_path,transform=transforms.Compose([RandomCroper(),RandomFlip()]))
    # Preprocessing is resize + pad-to-square only; the flip/crop
    # augmentations are commented out above.
    dataset_train = TrainDataset(train_path,
                                 transform=transforms.Compose(
                                     [Resizer(), PadToSquare()]))
    dataloader_train = DataLoader(dataset_train,
                                  num_workers=6,
                                  batch_size=args.batch,
                                  collate_fn=collater,
                                  shuffle=True)
    # dataset_val = ValDataset(val_path,transform=transforms.Compose([RandomCroper()]))
    dataset_val = ValDataset(val_path,
                             transform=transforms.Compose(
                                 [Resizer(), PadToSquare()]))
    dataloader_val = DataLoader(dataset_val,
                                num_workers=8,
                                batch_size=args.batch,
                                collate_fn=collater)

    # Number of batches per epoch; only used for progress logging.
    total_batch = len(dataloader_train)

    # Create the model
    # if args.depth == 18:
    #     retinaface = model.resnet18(num_classes=2, pretrained=True)
    # elif args.depth == 34:
    #     retinaface = model.resnet34(num_classes=2, pretrained=True)
    # elif args.depth == 50:
    #     retinaface = model.resnet50(num_classes=2, pretrained=True)
    # elif args.depth == 101:
    #     retinaface = model.resnet101(num_classes=2, pretrained=True)
    # elif args.depth == 152:
    #     retinaface = model.resnet152(num_classes=2, pretrained=True)
    # else:
    #     raise ValueError('Unsupported model depth, must be one of 18, 34, 50, 101, 152')

    # Create torchvision model
    # Backbone layers tapped for the feature pyramid: layer2/3/4 -> 1/2/3.
    return_layers = {'layer2': 1, 'layer3': 2, 'layer4': 3}
    retinaface = torchvision_model.create_retinaface(return_layers)

    # Load trained model
    if (args.model_path is not None):
        retina_dict = retinaface.state_dict()
        pre_state_dict = torch.load(args.model_path)
        # k[7:] strips the "module." prefix that DataParallel checkpoints
        # carry; only keys that match the fresh model are kept.
        pretrained_dict = {
            k[7:]: v
            for k, v in pre_state_dict.items() if k[7:] in retina_dict
        }
        retinaface.load_state_dict(pretrained_dict)

    retinaface = retinaface.cuda()
    retinaface = torch.nn.DataParallel(retinaface).cuda()
    # NOTE(review): sets a plain attribute rather than calling .train();
    # retinaface.train() is called per epoch below anyway.
    retinaface.training = True

    optimizer = optim.Adam(retinaface.parameters(), lr=1e-3)
    # optimizer = optim.SGD(retinaface.parameters(), lr=1e-2, momentum=0.9, weight_decay=0.0005)
    # scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=3, verbose=True)
    # scheduler  = optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    #scheduler  = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10,30,60], gamma=0.1)

    #performance detect
    # print('-------- RetinaFace Pytorch --------')
    # recall, precision = eval_widerface.evaluate(dataloader_val, retinaface)
    # print('Recall:', recall)
    # print('Precision:', precision, "best Precision:", precision_global)

    print('Start to train.')

    # NOTE(review): epoch_loss is never appended to or read — dead variable.
    epoch_loss = []
    # Counts logging events (incremented once per args.verbose batches),
    # not raw batches.
    iteration = 0

    for epoch in range(args.epochs):
        retinaface.train()

        # Training
        for iter_num, data in enumerate(dataloader_train):
            #ff = data["img"].numpy()
            #print(ff[0][1][320][320])
            optimizer.zero_grad()
            # Forward pass returns the three RetinaFace loss heads;
            # .mean() reduces across DataParallel replicas.
            classification_loss, bbox_regression_loss, ldm_regression_loss = retinaface(
                [data['img'].cuda().float(), data['annot']])
            classification_loss = classification_loss.mean()
            bbox_regression_loss = bbox_regression_loss.mean()
            ldm_regression_loss = ldm_regression_loss.mean()

            # loss = classification_loss + 1.0 * bbox_regression_loss + 0.5 * ldm_regression_loss
            # Equal weighting of the three loss terms (weighted variant above
            # is disabled).
            loss = classification_loss + bbox_regression_loss + ldm_regression_loss

            loss.backward()
            optimizer.step()

            # Pretty-print the loss table every args.verbose batches.
            if iter_num % args.verbose == 0:
                log_str = "\n---- [Epoch %d/%d, Batch %d/%d] ----\n" % (
                    epoch, args.epochs, iter_num, total_batch)
                table_data = [['loss name', 'value'],
                              ['total_loss', str(loss.item())],
                              [
                                  'classification',
                                  str(classification_loss.item())
                              ], ['bbox',
                                  str(bbox_regression_loss.item())],
                              ['landmarks',
                               str(ldm_regression_loss.item())]]
                table = AsciiTable(table_data)
                log_str += table.table
                print(log_str)
                # write the log to tensorboard
                # NOTE(review): x-axis is iteration * args.verbose, i.e. an
                # approximation of the batch index within the epoch that
                # resets relative meaning across epochs — confirm this is the
                # intended global step.
                writer.add_scalar('losses:', loss.item(),
                                  iteration * args.verbose)
                writer.add_scalar('class losses:', classification_loss.item(),
                                  iteration * args.verbose)
                writer.add_scalar('box losses:', bbox_regression_loss.item(),
                                  iteration * args.verbose)
                writer.add_scalar('landmark losses:',
                                  ldm_regression_loss.item(),
                                  iteration * args.verbose)
                iteration += 1

        # Eval
        if epoch % args.eval_step == 0:
            print('-------- RetinaFace Pytorch --------')
            print('Evaluating epoch {}'.format(epoch))
            recall, precision = eval_widerface.evaluate(
                dataloader_val, retinaface)
            # Keep only the best-precision checkpoint under this name.
            if (precision_global < precision):
                precision_global = precision
                torch.save(
                    retinaface.state_dict(), args.save_path +
                    '/model_Best_epoch_{}.pt'.format(epoch + 1))
            print('Recall:', recall)
            print('Precision:', precision, "best Precision:", precision_global)

            # NOTE(review): x-axis epoch * args.eval_step looks wrong — since
            # this branch only runs when epoch % eval_step == 0, plain `epoch`
            # would be the natural step; confirm intent.
            writer.add_scalar('Recall:', recall, epoch * args.eval_step)
            writer.add_scalar('Precision:', precision, epoch * args.eval_step)

        # Save model
        # Periodic snapshot every args.save_step epochs (1-based epoch count).
        if (epoch + 1) % args.save_step == 0:
            torch.save(retinaface.state_dict(),
                       args.save_path + '/model_epoch_{}.pt'.format(epoch + 1))

    writer.close()
Beispiel #29
0
def run(args, logger):
    """Train and evaluate a KGE model with chunk-based negative sampling.

    Pipeline: load the dataset, build one bidirectional (head/tail) training
    iterator per worker process, optionally build valid/test samplers, load
    the model, train (multi-process via ``mp.Process`` when
    ``args.num_proc > 1``, otherwise in-process), optionally save embeddings,
    and finally run the test phase, averaging per-process metrics through an
    ``mp.Queue``.

    Args:
        args: parsed command-line namespace (mutated in place to resolve
            negative-sample / chunk-size defaults).
        logger: logger handed through to ``load_model``.
    """
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    # A negative test-sample size < 0 means "corrupt against all entities".
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities
    args.eval_filter = not args.no_eval_filter
    if args.neg_deg_sample_eval:
        assert not args.eval_filter, "if negative sampling based on degree, we can't filter positive edges."

    # When we generate a batch of negative edges from a set of positive edges,
    # we first divide the positive edges into chunks and corrupt the edges in a chunk
    # together. By default, the chunk size is equal to the negative sample size.
    # Usually, this works well. But we also allow users to specify the chunk size themselves.
    if args.neg_chunk_size < 0:
        args.neg_chunk_size = args.neg_sample_size
    if args.neg_chunk_size_valid < 0:
        args.neg_chunk_size_valid = args.neg_sample_size_valid
    if args.neg_chunk_size_test < 0:
        args.neg_chunk_size_test = args.neg_sample_size_test

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        # One head-corrupting and one tail-corrupting sampler per rank,
        # wrapped into a bidirectional iterator that alternates between them.
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='chunk-head',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                args.neg_chunk_size,
                mode='chunk-tail',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail,
                                                args.neg_chunk_size, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            mode='chunk-head',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            args.neg_chunk_size,
            mode='chunk-tail',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, args.neg_chunk_size, True,
            n_entities)

    # for multiprocessing evaluation, we don't need to sample multiple batches at a time
    # in each process.
    num_workers = args.num_worker
    if args.num_proc > 1:
        num_workers = 1
    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regular negative sampler because we need to
        # ensure that all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    args.neg_chunk_size_valid,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                args.neg_chunk_size_valid,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regular negative sampler because we need to
        # ensure that all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-head',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    args.neg_chunk_size_test,
                    args.eval_filter,
                    mode='chunk-tail',
                    num_workers=num_workers,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-head',
                num_workers=num_workers,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                args.neg_chunk_size_test,
                args.eval_filter,
                mode='chunk-tail',
                num_workers=num_workers,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
        # Share parameters across worker processes (Hogwild-style updates).
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            rel_parts = train_data.rel_parts if args.rel_part else None
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]
                              ] if args.valid else None
            proc = mp.Process(target=train,
                              args=(args, model, train_samplers[i], i,
                                    rel_parts, valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        # NOTE(review): the multi-process path calls train(args, model,
        # sampler, rank, rel_parts, valid_samplers) with 6 positional args,
        # while this call passes only 4 — if train() takes rank/rel_parts
        # positionally, valid_samplers lands in the rank slot. Confirm
        # against train()'s signature before relying on validation here.
        train(args, model, train_sampler, valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    if args.test:
        start = time.time()
        if args.num_proc > 1:
            queue = mp.Queue(args.num_proc)
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ], i, 'Test', queue))
                procs.append(proc)
                proc.start()

            # Average each metric over the per-process results pulled from
            # the queue (order does not matter for an average).
            total_metrics = {}
            for i in range(args.num_proc):
                metrics = queue.get()
                for k, v in metrics.items():
                    if i == 0:
                        total_metrics[k] = v / args.num_proc
                    else:
                        total_metrics[k] += v / args.num_proc
            # BUG FIX: report the averaged metrics; previously this iterated
            # `metrics` (the last process's raw result) while printing it as
            # the "Test average".
            for k, v in total_metrics.items():
                print('Test average {} at [{}/{}]: {}'.format(
                    k, args.step, args.max_step, v))

            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])
        print('test:', time.time() - start)
def run(args, logger):
    """Train and evaluate a KGE model with PBG-style negative sampling.

    Variant of the chunked pipeline above that uses 'PBG-head'/'PBG-tail'
    sampler modes and no chunk-size plumbing. NOTE(review): this is the
    second ``run(args, logger)`` defined in this file — at import time it
    shadows the chunk-based variant above.
    """
    # load dataset and samplers
    dataset = get_dataset(args.data_path, args.dataset, args.format)
    n_entities = dataset.n_entities
    n_relations = dataset.n_relations
    # A negative test-sample size < 0 means "corrupt against all entities".
    if args.neg_sample_size_test < 0:
        args.neg_sample_size_test = n_entities

    train_data = TrainDataset(dataset, args, ranks=args.num_proc)
    if args.num_proc > 1:
        # One head- and one tail-corrupting sampler per rank, paired into a
        # bidirectional iterator.
        train_samplers = []
        for i in range(args.num_proc):
            train_sampler_head = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-head',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_sampler_tail = train_data.create_sampler(
                args.batch_size,
                args.neg_sample_size,
                mode='PBG-tail',
                num_workers=args.num_worker,
                shuffle=True,
                exclude_positive=True,
                rank=i)
            train_samplers.append(
                NewBidirectionalOneShotIterator(train_sampler_head,
                                                train_sampler_tail, True,
                                                n_entities))
    else:
        train_sampler_head = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-head',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler_tail = train_data.create_sampler(
            args.batch_size,
            args.neg_sample_size,
            mode='PBG-tail',
            num_workers=args.num_worker,
            shuffle=True,
            exclude_positive=True)
        train_sampler = NewBidirectionalOneShotIterator(
            train_sampler_head, train_sampler_tail, True, n_entities)

    if args.valid or args.test:
        eval_dataset = EvalDataset(dataset, args)
    if args.valid:
        # Here we want to use the regualr negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            valid_sampler_heads = []
            valid_sampler_tails = []
            for i in range(args.num_proc):
                valid_sampler_head = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    mode='PBG-head',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_tail = eval_dataset.create_sampler(
                    'valid',
                    args.batch_size_eval,
                    args.neg_sample_size_valid,
                    mode='PBG-tail',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                valid_sampler_heads.append(valid_sampler_head)
                valid_sampler_tails.append(valid_sampler_tail)
        else:
            valid_sampler_head = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
            valid_sampler_tail = eval_dataset.create_sampler(
                'valid',
                args.batch_size_eval,
                args.neg_sample_size_valid,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
    if args.test:
        # Here we want to use the regualr negative sampler because we need to ensure that
        # all positive edges are excluded.
        if args.num_proc > 1:
            test_sampler_tails = []
            test_sampler_heads = []
            for i in range(args.num_proc):
                test_sampler_head = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    mode='PBG-head',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_tail = eval_dataset.create_sampler(
                    'test',
                    args.batch_size_eval,
                    args.neg_sample_size_test,
                    mode='PBG-tail',
                    num_workers=args.num_worker,
                    rank=i,
                    ranks=args.num_proc)
                test_sampler_heads.append(test_sampler_head)
                test_sampler_tails.append(test_sampler_tail)
        else:
            test_sampler_head = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                mode='PBG-head',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)
            test_sampler_tail = eval_dataset.create_sampler(
                'test',
                args.batch_size_eval,
                args.neg_sample_size_test,
                mode='PBG-tail',
                num_workers=args.num_worker,
                rank=0,
                ranks=1)

    # We need to free all memory referenced by dataset.
    eval_dataset = None
    dataset = None
    # load model
    model = load_model(logger, args, n_entities, n_relations)

    if args.num_proc > 1:
        # Share parameters across worker processes (Hogwild-style updates).
        model.share_memory()

    # train
    start = time.time()
    if args.num_proc > 1:
        procs = []
        for i in range(args.num_proc):
            valid_samplers = [valid_sampler_heads[i], valid_sampler_tails[i]
                              ] if args.valid else None
            proc = mp.Process(target=train,
                              args=(args, model, train_samplers[i],
                                    valid_samplers))
            procs.append(proc)
            proc.start()
        for proc in procs:
            proc.join()
    else:
        valid_samplers = [valid_sampler_head, valid_sampler_tail
                          ] if args.valid else None
        train(args, model, train_sampler, valid_samplers)
    print('training takes {} seconds'.format(time.time() - start))

    if args.save_emb is not None:
        if not os.path.exists(args.save_emb):
            os.mkdir(args.save_emb)
        model.save_emb(args.save_emb, args.dataset)

    # test
    # NOTE(review): unlike the chunked variant, per-process test results are
    # not collected/averaged here — each process reports independently.
    if args.test:
        if args.num_proc > 1:
            procs = []
            for i in range(args.num_proc):
                proc = mp.Process(target=test,
                                  args=(args, model, [
                                      test_sampler_heads[i],
                                      test_sampler_tails[i]
                                  ]))
                procs.append(proc)
                proc.start()
            for proc in procs:
                proc.join()
        else:
            test(args, model, [test_sampler_head, test_sampler_tail])