def train_and_send(global_model_weights, current_epoch, IDS_df):
    device = 'cpu'
    if torch.cuda.is_available():
        device = 'cuda'

    # Defining the DNN model
    input_size = model_input_size
    model = MLP(input_size)
    model.load_state_dict(torch.load(global_model_weights))
    model.to(device)

    # Cross Entropy Loss
    error = nn.CrossEntropyLoss().to(device)

    # Adam Optimizer
    learning_rate = 0.001
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=learning_rate,
                                 weight_decay=0.01)

    model, loss = train_model_stratified(model, optimizer, error, device,
                                         current_epoch, IDS_df)

    # Encode model weights and send
    model.to('cpu')
    model_str = encode_weights(model)
    remote_mqttclient.publish(TRAINED_MODEL_TOPIC,
                              payload=model_str,
                              qos=2,
                              retain=False)
    remote_mqttclient.publish(TRAINED_LOSS_TOPIC,
                              payload=str(loss),
                              qos=2,
                              retain=False)
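
# --- Hedged sketch (not part of the original snippet) ---
# encode_weights() is referenced above but not defined here; assuming the MQTT
# payload should be a printable string, one plausible implementation serializes
# the state_dict and base64-encodes it.
import base64
import io

def encode_weights(model):
    buffer = io.BytesIO()
    torch.save(model.state_dict(), buffer)  # serialize weights to an in-memory buffer
    return base64.b64encode(buffer.getvalue()).decode('ascii')  # bytes -> printable payload
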
def train(config_files, run_number):
    max_accuracy = []
    max_validation_accuracy = []
    for n in range(1, 2):
        X, y = get_data_train(n, config_files, run_number)
        input_size, hidden_size, output_size = X.shape[1], 16, 8
        model = MLP(input_size, hidden_size, output_size)
        model.to(device)
        X, y = X.to(device), y.to(device)
        epochs = 20
        accuracy = []
        test_accuracy = []
        for i in range(epochs):
            output_i, loss = train_optim(model, y, X)
            print("epoch {}".format(i))
            print("accuracy = ", np.sum(output_i == y.cpu().numpy()) / y.size())
            print("loss: {}".format(loss))
            accuracy.append((np.sum(output_i == y.cpu().numpy()) / y.size())[0])
            if not os.path.isdir('checkpoint'):
                os.mkdir('checkpoint')
            torch.save(model.state_dict(), "checkpoint/MLP_model_{}_train.pwf".format(i))
            test_accuracy.append(validate(n, config_files, run_number, model))
            torch.save(model.state_dict(), "checkpoint/MLP_model_{}_validate.pwf".format(i))

        plot_accuracy_n_print(accuracy, max_accuracy, n, run_number, 'train')
        plot_accuracy_n_print(test_accuracy, max_validation_accuracy, n, run_number, 'validate')
def model_fn(model_dir):
    """Load the PyTorch model from the `model_dir` directory."""
    print("Loading model.")

    # First, load the parameters used to create the model.
    model_info = {}
    model_info_path = os.path.join(model_dir, 'model_info.pth')
    with open(model_info_path, 'rb') as f:
        model_info = torch.load(f)

    print("model_info: {}".format(model_info))

    # Determine the device and construct the model.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = MLP(model_info['dim_input'], model_info['dim_hidden'],
                model_info['dim_output'])

    # Load the stored model parameters.
    model_path = os.path.join(model_dir, 'model.pth')
    with open(model_path, 'rb') as f:
        model.load_state_dict(torch.load(f))

    # prep for testing
    model.to(device).eval()

    print("Done loading model.")
    return model
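
# --- Hedged sketch (not part of the original snippet) ---
# The save step that produces the files model_fn() expects is not shown;
# assuming the usual pairing, it writes model_info.pth with the constructor
# arguments and model.pth with the weights (the names mirror the loader above).
def save_model(model, model_dir, dim_input, dim_hidden, dim_output):
    with open(os.path.join(model_dir, 'model_info.pth'), 'wb') as f:
        torch.save({'dim_input': dim_input,
                    'dim_hidden': dim_hidden,
                    'dim_output': dim_output}, f)
    with open(os.path.join(model_dir, 'model.pth'), 'wb') as f:
        torch.save(model.cpu().state_dict(), f)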
Example #4
def gpu_thread(load, memory_queue, process_queue, common_dict, worker):
    # the only thread that has an access to the gpu, it will then perform all the NN computation
    import psutil
    p = psutil.Process()
    p.cpu_affinity([worker])
    import signal
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    try:
        print('process started with pid: {} on core {}'.format(
            os.getpid(), worker),
              flush=True)
        model = MLP(parameters.OBS_SPACE, parameters.ACTION_SPACE)
        model.to(parameters.DEVICE)
        # optimizer = optim.Adam(model.parameters(), lr=5e-5)
        # optimizer = optim.SGD(model.parameters(), lr=3e-2)
        optimizer = optim.RMSprop(model.parameters(), lr=1e-4)
        epochs = 0
        if load:
            checkpoint = torch.load('./model/walker.pt')
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            epochs = checkpoint['epochs']
        observations = torch.Tensor([]).to(parameters.DEVICE)
        rewards = torch.Tensor([]).to(parameters.DEVICE)
        actions = torch.Tensor([]).to(parameters.DEVICE)
        probs = torch.Tensor([]).to(parameters.DEVICE)
        common_dict['epoch'] = epochs
        while True:
            memory_full, observations, rewards, actions, probs = \
                destack_memory(memory_queue, observations, rewards, actions, probs)
            destack_process(model, process_queue, common_dict)
            if len(observations) > parameters.MAXLEN or memory_full:
                epochs += 1
                print('-' * 60 + '\n        epoch ' + str(epochs) + '\n' +
                      '-' * 60)
                run_epoch(epochs, model, optimizer, observations, rewards,
                          actions, probs)
                observations = torch.Tensor([]).to(parameters.DEVICE)
                rewards = torch.Tensor([]).to(parameters.DEVICE)
                actions = torch.Tensor([]).to(parameters.DEVICE)
                probs = torch.Tensor([]).to(parameters.DEVICE)
                torch.save(
                    {
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'epochs': epochs
                    }, './model/walker.pt')
                common_dict['epoch'] = epochs
    except Exception as e:
        print(e)
        print('saving before interruption', flush=True)
        torch.save(
            {
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'epochs': epochs
            }, './model/walker.pt')
Example #5
def train_net(net,
              device,
              epochs=100,
              batch_size=32,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              ):
    ...  # training-loop body not included in this snippet

def get_args():
    parser = argparse.ArgumentParser(description='Train the PointRender on images and pre-processed labels',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('-e', '--epochs', metavar='E', type=int, default=5,
                        help='Number of epochs', dest='epochs')
    parser.add_argument('-b', '--batch-size', metavar='B', type=int, nargs='?', default=1,
                        help='Batch size', dest='batchsize')
    parser.add_argument('-l', '--learning-rate', metavar='LR', type=float, nargs='?', default=0.1,
                        help='Learning rate', dest='lr')
    parser.add_argument('-f', '--load', dest='load', type=str, default=False,
                        help='Load model from a .pth file')
    parser.add_argument('-v', '--validation', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    net = MLP(input_voxel=1, n_classes=3)

    if args.load:
        net.load_state_dict(
            torch.load(args.load, map_location=device)
        )

    net.to(device=device)

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
Example #6
def load_model(save_path):
    # torch.save(to_save, save_path)
    model = MLP(len(vocab), HIDDEN_SIZE, num_classes, device=device)

    checkpoint = torch.load(save_path + '/best_model.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    epoch = checkpoint['epoch']

    # move the model to GPU if has one
    model.to(device)

    # need this for dropout
    model.eval()
    return model
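
# --- Hedged sketch (not part of the original snippet) ---
# The commented-out torch.save(to_save, save_path) above hints at the save step;
# an assumed counterpart writes the same keys that load_model() reads back.
def save_best_model(model, epoch, save_path):
    to_save = {'model_state_dict': model.state_dict(), 'epoch': epoch}
    torch.save(to_save, save_path + '/best_model.pt')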
Example #7
def main(dataset, dim, layers, lr, reg, epochs, batchsize):
    n_user = overlap_user(dataset)
    print(n_user)
    logging.info(str(n_user))
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mf_s, mf_t = load_model(dataset, dim)
    mapping = MLP(dim, layers)
    mf_s = mf_s.to(device)
    mf_t = mf_t.to(device)
    mapping = mapping.to(device)
    opt = torch.optim.Adam(mapping.parameters(), lr=lr, weight_decay=reg)
    mse_loss = nn.MSELoss()

    start = time()
    for epoch in range(epochs):
        loss_sum = 0
        for users in batch_user(n_user, batchsize):
            us = torch.tensor(users).long()
            us = us.to(device)
            u = mf_s.get_embed(us)
            y = mf_t.get_embed(us)
            loss = train(mapping, opt, mse_loss, u, y)
            loss_sum += loss
        print('Epoch %d [%.1f] loss = %f' % (epoch, time()-start, loss_sum))
        logging.info('Epoch %d [%.1f] loss = %f' %
                     (epoch, time()-start, loss_sum))
        start = time()

    mfile = 'pretrain/%s/Mapping.pth.tar' % dataset
    torch.save(mapping.state_dict(), mfile)
    print('save [%.1f]' % (time()-start))
    logging.info('save [%.1f]' % (time()-start))
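
# --- Hedged sketch (not part of the original snippet) ---
# The train() step called in the batch loop is not shown; given that an optimizer
# and an MSE criterion are passed in, a plausible version does one gradient step
# mapping the source embedding u onto the target embedding y.
def train(mapping, opt, mse_loss, u, y):
    opt.zero_grad()
    loss = mse_loss(mapping(u), y)  # align mapped source embeddings with target embeddings
    loss.backward()
    opt.step()
    return loss.item()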
Example #8
def main():
    parser = ArgumentParser(description='train a MLP model')
    parser.add_argument('INPUT', type=str, help='path to input')
    parser.add_argument('EMBED', type=str, help='path to embedding')
    parser.add_argument('--gpu', '-g', default=-1, type=int, help='gpu number')
    args = parser.parse_args()

    word_to_id = word2id(args.INPUT)
    embedding = id2embedding(args.EMBED, word_to_id)

    train_loader = MyDataLoader(args.INPUT,
                                word_to_id,
                                batch_size=5000,
                                shuffle=True,
                                num_workers=1)
    # create the model instance
    net = MLP(word_to_id, embedding)
    optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

    gpu_id = args.gpu
    device = torch.device("cuda:{}".format(gpu_id) if gpu_id >= 0 else "cpu")
    net = net.to(device)

    epochs = 5
    log_interval = 10
    for epoch in range(1, epochs + 1):
        net.train()  # switch to training mode (matters when using Dropout, etc.)
        for batch_idx, (ids, mask, labels) in enumerate(train_loader):
            # each batch yields (ids, mask, labels)

            ids, mask, labels = ids.to(device), mask.to(device), labels.to(
                device)
            optimizer.zero_grad()  # zero the gradients first; otherwise gradients from previous steps keep accumulating
            output = net(ids, mask)
            output2 = F.softmax(output, dim=1)
            loss = F.binary_cross_entropy(output2[:, 1],
                                          labels.float())  # compute the loss
            loss.backward()
            optimizer.step()  # update the parameters

            # print intermediate progress
            if batch_idx % log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(ids), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
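
# Note (not part of the original snippet): with exactly two classes, applying
# softmax and then binary_cross_entropy to column 1, as above, is mathematically
# the same as cross-entropy on the raw logits, so a numerically more stable
# alternative would be:
#     loss = F.cross_entropy(output, labels.long())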
Example #9
		return sp, self.discrete_freq


if __name__ == '__main__':
	args = docopt(__doc__)
	print("Command line args:\n", args)
	numlayer = int(args['-l'])
	numunit = int(args['-u'])
	model_path = args['-m']
	gpuid = int(args['-g'])
	
	dtype = torch.float
	device = torch.device("cuda:"+str(gpuid) if gpuid>=0 else "cpu")
	
	FFTSIZE = 1024
	FS = 16000 # [Hz]
	
	model = MLP(in_dim=FFTSIZE//2+1, out_dim=FFTSIZE//2+1, numlayer=numlayer, numunit=numunit)
	model.load_state_dict(torch.load(model_path))
	model = model.to(device)
	model.eval()
	
	f02sp = F02SP(FFTSIZE,FS)
	f0 = 0.1 * np.arange(200,5000+1) # input, 0.1~800 [Hz]
	sp, discrete_freq = f02sp.get_sp(f0)
	input_sequence = discrete_freq / f0[:,np.newaxis]
	input_sequence = torch.from_numpy(input_sequence).to(dtype).to(device)
	
	pred_sp = model(input_sequence)
	pred_sp = pred_sp.cpu().data.numpy()
Example #10
    #
    print(f"\t✅ start training and evaluating process\n")
    # -----------------------------------------------------------
    valid_loss_min = np.Inf
    criterion = torch.nn.CrossEntropyLoss()
    start_time = time.time()
    history = {
        "train_loss": [],
        "valid_loss": [],
        "train_acc": [],
        "valid_acc": []
    }

    for e in range(args.epochs):
        loggings = TrainEvalCNN(
            net.to(device),
            device,
            e,
            train_iter,
            valid_iter,
            optimizer=optimizer,
            criterion=criterion) if optimizer else TrainEvalMLP(
                net.to(device),
                device,
                e,
                train_iter,
                valid_iter,
                criterion=criterion)
        train_loss, train_acc, valid_loss, valid_acc = loggings[0], loggings[
            1], loggings[2], loggings[3]
        history["train_loss"].append(train_loss)
Example #11
        for idx, params in enumerate(HYPER_PARA):
            BATCH_SIZE, DROP_RT, LR, EPOCHS = params
            best_roc = []
            best_prc = []

            data_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)

            for run in range(1):
                model = MLP(NUM_ELEM, EMBEDDING_DIM,  HIDDEN_DIM_ADD_ON, HIDDEN_DIM, NUM_CLS, NUM_LYS, ADD_ON_FEATS, 100,
                            DROP_RT)
                loss_func = nn.CrossEntropyLoss()
                optimizer = optim.Adam(model.parameters(), lr=LR)

                device = torch.device('cuda:0')
                model.to(device)

                last_roc = -1
                last_prc = -1
                epochs_no_imprv = 0
                for epoch in range(EPOCHS):
                    model.train()
                    epoch_loss = 0
                    batch = tqdm(data_loader)
                    for elem, label, lengths, feats in batch:
                        optimizer.zero_grad()
                        prediction = model(elem, lengths, feats)
                        # loss = torch.mean(F.cross_entropy(prediction, label, reduction='none')
                        #                   * (torch.ones_like(label, dtype=torch.float, device=device) +
                        #                      torch.tensor(label.clone().detach(), dtype=torch.float,
                        #                                   device=device) * 8))
        random.seed(run)
        torch.cuda.manual_seed_all(run)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    # CLASSIFIER
    if args.use_conv:
        if args.imprint:

            model = ResNet18_imprint(num_classes=args.n_tasks*5)
            model.seen_classes = []
        else:
            model = ResNet18(args.n_classes, nf=20, input_size=args.input_size)
    else:
        model = MLP(args)
    if args.cuda:
        model = model.to(args.device)

    opt = torch.optim.SGD(model.parameters(), lr=args.lr)
    buffer = Buffer(args)
    if run == 0:
        print("number of classifier parameters:",
                sum([np.prod(p.size()) for p in model.parameters()]))
        print("buffer parameters: ", np.prod(buffer.bx.size()))

    #----------
    # Task Loop

    for task, tr_loader in enumerate(train_loader):

        sample_amt = 0
Example #13
    def run(self, epoch=None):
        '''
        Run a training loop.
        '''
        if epoch is None:
            epoch = self.args.n_epoch
        print('mbsize: {}, hidden size: {}, layer: {}, dropout: {}'.format(
            self.mb, self.hidden, self.layer, self.dropout),
              file=self.file)
        # Init the model, the optimizer and some structures for logging
        self.reset()

        model = MLP(self.hidden, self.layer, self.dropout)
        #print(model)
        model.reset_parameters()
        #model.cuda()
        model.to(self.args.device)
        #model = torch.nn.DataParallel(model)
        opt = optim.Adam(model.parameters())

        acc = 0  # best dev. acc.
        accc = 0  # test acc. at the time of best dev. acc.
        e = -1  # best dev iteration/epoch

        times = []
        losses = []
        ftime = []
        btime = []
        utime = []

        print('Initial evaluation on dev set:')
        self._evaluate(model, self.devloader, 'dev')

        start = time.time()
        # training loop
        for t in range(epoch):
            print('{}:'.format(t), end='', file=self.file, flush=True)
            # train

            #start = torch.cuda.Event(True)
            #end = torch.cuda.Event(True)
            #start.record()
            loss, ft, bt, ut = self._train(model, opt)
            #end.record()
            #end.synchronize()
            #ttime = start.elapsed_time(end)
            print("(wall time: {:.1f} sec) ".format(time.time() - start),
                  end='')
            #times.append(ttime)
            losses.append(loss)
            ftime.append(ft)
            btime.append(bt)
            utime.append(ut)
            # predict
            curacc = self._evaluate(model, self.devloader, 'dev')
            #if curacc > acc:
            #    e = t
            #    acc = curacc
            #    accc = self._evaluate(model, self.testloader, '    test')
        etime = [sum(t) for t in zip(ftime, btime, utime)]
        print('test acc: {:.2f}'.format(
            self._evaluate(model, self.testloader, '    test')))
        print('best on val set - ${:.2f}|{:.2f} at {}'.format(acc, accc, e),
              file=self.file,
              flush=True)
        print('', file=self.file)
Example #14
def train(args):
    # Prepare data, model, device, loss, metric, logger
    mag_dataset = MAGDatasetSlim(name="", path=args.data)
    pretrained_embedding = mag_dataset.g_full.ndata['x'].numpy()
    vocab_size, embed_dim = pretrained_embedding.shape
    if args.arch == "MLP":
        train_loader = EdgeDataLoader(data_path=args.data, mode="train", batch_size=args.bs_train, negative_size=args.ns_train)
        validation_loader = EdgeDataLoader(data_path=args.data, mode="validation", batch_size=args.bs_train, negative_size=args.ns_validation)
        model = MLP(vocab_size, embed_dim, first_hidden=1000, second_hidden=500, activation=nn.LeakyReLU(), pretrained_embedding=pretrained_embedding)
    elif args.arch == "DeepSetMLP":
        train_loader = SubGraphDataLoader(data_path=args.data, mode="train", batch_size=args.bs_train, negative_size=args.ns_train)
        validation_loader = SubGraphDataLoader(data_path=args.data, mode="validation", batch_size=args.bs_train, negative_size=args.ns_validation)
        model = DeepSetMLP(vocab_size, embed_dim, first_hidden=1500, second_hidden=1000, activation=nn.LeakyReLU(), pretrained_embedding=pretrained_embedding)
    else:
        train_loader = AnchorParentDataLoader(data_path=args.data, mode="train", batch_size=args.bs_train, negative_size=args.ns_train)
        validation_loader = AnchorParentDataLoader(data_path=args.data, mode="validation", batch_size=args.bs_train, negative_size=args.ns_validation)
        model = DeepAPGMLP(vocab_size, embed_dim, first_hidden=2000, second_hidden=1000, activation=nn.LeakyReLU(), pretrained_embedding=pretrained_embedding)
    
    if args.device == "-1":
        device = torch.device("cpu")
    else:
        device = torch.device(f"cuda:{args.device}")
    loss_fn = bce_loss
    metric_fn = [macro_averaged_rank, batched_topk_hit_1, batched_topk_hit_3, batched_topk_hit_5, batched_scaled_MRR]

    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)
    model = model.to(device)

    # Start training
    start = time.time()
    model.train()
    mnt_best, mnt_mode, mnt_metric = 1e10, "min", "macro_averaged_rank"
    not_improved_count = 0
    for epoch in range(args.max_epoch):
        total_loss = 0
        total_metrics = np.zeros(len(metric_fn))
        for batch_idx, examples in enumerate(train_loader):
            if len(examples) == 3:
                optimizer.zero_grad()
                parents, children, labels = examples[0].to(device), examples[1].to(device), examples[2].to(device)
                prediction = model(parents, children)
            elif len(examples) == 4:
                optimizer.zero_grad()
                parents, siblings, children, labels = examples[0].to(device), examples[1].to(device), examples[2].to(device), examples[3].to(device)
                prediction = model(parents, siblings, children)
            else:
                optimizer.zero_grad()
                parents, siblings, grand_parents, children, labels = examples[0].to(device), examples[1].to(device), examples[2].to(device), examples[3].to(device), examples[4].to(device)
                prediction = model(parents, siblings, grand_parents, children)
                
            loss = loss_fn(prediction, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            total_metrics += eval_metrics(metric_fn, prediction.detach(), labels.detach())
        
        total_loss = total_loss / len(train_loader)
        total_metrics = (total_metrics/ len(train_loader)).tolist()
        print(f"Epoch {epoch}: loss: {total_loss}")
        for i in range(len(metric_fn)):
            print(f"    {metric_fn[i].__name__}: {total_metrics[i]}")

        # validation and early stopping
        if (epoch+1) % args.save_period == 0:
            best = False
            val_loss, val_metrics = valid_epoch(model, device, validation_loader, loss_fn, metric_fn)
            scheduler.step(val_metrics[mnt_metric])
            print(f"    Validation loss: {val_loss}")
            for i in range(len(metric_fn)):
                print(f"    Validation {metric_fn[i].__name__}: {val_metrics[metric_fn[i].__name__]}")
            improved = (mnt_mode == 'min' and val_metrics[mnt_metric] <= mnt_best) or (mnt_mode == 'max' and val_metrics[mnt_metric] >= mnt_best)
            if improved:
                mnt_best = val_metrics[mnt_metric]
                not_improved_count = 0
                best = True
            else:
                not_improved_count += 1
            
            if not_improved_count > args.early_stop:
                print(f"Validation performance didn\'t improve for {args.early_stop} epochs. Training stops.")
                break

            save_checkpoint(model, epoch, optimizer, mnt_best, args.checkpoint_dir, save_best=best)
    end = time.time()
    print(f"Finish training in {end-start} seconds")
Example #15
class Test():
    def __init__(self, config_path):
        config = configparser.ConfigParser()
        config.read(config_path)

        self.save_dir = Path(config.get("general", "save_dir"))
        if not self.save_dir.exists():
            self.save_dir.mkdir(parents=True)
        self.clf_th = config.getfloat("general", "clf_th")

        self.mlp_model_path = config.get("model", "mlp")
        assert Path(self.mlp_model_path).exists()

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        bert_config_path = config.get("bert", "config_path")
        assert Path(bert_config_path).exists()
        self.bert_config = LongformerConfig.from_json_file(bert_config_path)
        self.max_seq_length = self.bert_config.max_position_embeddings - 2
        self.bert_tokenizer = LongformerTokenizer.from_pretrained(
            'allenai/longformer-base-4096')
        # bert_tokenizer_path = config.get("bert", "tokenizer_path")
        # assert Path(bert_config_path).exists()
        # self.bert_tokenizer = LongformerTokenizer.from_pretrained(bert_tokenizer_path)
        bert_model_path = config.get("bert", "model_path")
        assert Path(bert_model_path).exists()
        self.bert_model = LongformerModel.from_pretrained(
            bert_model_path, config=self.bert_config)
        self.bert_model.to(self.device)
        self.bert_model.eval()

        gold_dir = Path(config.get("data", "gold_dir"))
        assert Path(gold_dir).exists()
        self.gold_dataset = ConllDataset(gold_dir)
        target_dir = Path(config.get("data", "target_dir"))
        assert Path(target_dir).exists()
        self.target_dataset = ConllDataset(target_dir)

    def transforms(self, example, label_list, is_gold):
        feature = convert_single_example(example, label_list,
                                         self.max_seq_length,
                                         self.bert_tokenizer)
        label_ids = feature.label_ids
        label_map = feature.label_map
        if is_gold:
            gold_labels = [-1] * self.max_seq_length
            # Get "Element" or "Main" token indices
            for i, lid in enumerate(label_ids):
                if lid == label_map['B-Element']:
                    gold_labels[i] = 0
                elif lid == label_map['B-Main']:
                    gold_labels[i] = 1
                elif lid in (label_map['I-Element'], label_map['I-Main']):
                    gold_labels[i] = 2
                elif lid == label_map['X']:
                    gold_labels[i] = 3
            gold_labels = gold_labels
        else:
            gold_labels = [-1] * self.max_seq_length
            # Get "Element" or "Main" token indices
            for i, lid in enumerate(label_ids):
                if lid == label_map['B-Element']:
                    gold_labels[i] = 0
                elif lid == label_map['I-Element']:
                    gold_labels[i] = 2
                elif lid == label_map['X']:
                    gold_labels[i] = 3
            gold_labels = gold_labels
        # flush data to bert model
        input_ids = torch.tensor(feature.input_ids).unsqueeze(0).to(
            self.device)
        with torch.no_grad():
            bert_output = self.bert_model(input_ids)
        # lstm (ignore padding parts)
        bert_fv = bert_output[0]
        input_ids = torch.tensor(feature.input_ids)
        label_ids = torch.tensor(feature.label_ids)
        return bert_fv, input_ids, label_ids, label_map, gold_labels

    def load_model(self):
        # MLP
        self.mlp = MLP(self.bert_config.hidden_size)
        self.mlp.load_state_dict(torch.load(self.mlp_model_path))
        self.mlp.to(self.device)
        self.mlp.eval()

    def eval(self):
        self.load_model()

        correct_save_dir = self.save_dir / "correct"
        if not correct_save_dir.exists():
            correct_save_dir.mkdir(parents=True)
        incorrect_save_dir = self.save_dir / "incorrect"
        if not incorrect_save_dir.exists():
            incorrect_save_dir.mkdir(parents=True)

        tp, fp, tn, fn = 0, 0, 0, 0
        with torch.no_grad():
            for gold_data, target_data in tqdm(
                    zip(self.gold_dataset, self.target_dataset)):
                # flush to Bert
                gold_fname, gold_example = gold_data
                target_fname, target_example = target_data
                if not gold_fname == target_fname:
                    import pdb
                    pdb.set_trace()
                assert gold_fname == target_fname

                _, _, _, _, gold_labels = self.transforms(
                    gold_example, self.gold_dataset.label_list, is_gold=True)
                fvs, input_ids, label_ids, label_map, pred_labels = self.transforms(
                    target_example,
                    self.target_dataset.label_list,
                    is_gold=False)

                # extract Element/Main tokens
                is_correct = True
                _, ent_gold_labels, golds_mask = Trainer.extract_tokens(
                    fvs.squeeze(0), gold_labels)
                golds = {}
                if len(ent_gold_labels) >= 1:
                    i = 0
                    while True:
                        try:
                            ent_start = golds_mask.index(i)
                        except ValueError:
                            break
                        for n, j in enumerate(golds_mask[ent_start:]):
                            if j != i:
                                ent_end = (ent_start + n - 1)
                                break
                        golds[(ent_start, ent_end)] = ent_gold_labels[i]
                        i += 1

                ents, ent_pred_labels, preds_mask = Trainer.extract_tokens(
                    fvs.squeeze(0), pred_labels)

                preds = {}
                if len(ent_pred_labels) >= 1:
                    i = 0
                    while True:
                        try:
                            ent_start = preds_mask.index(i)
                        except ValueError:
                            break
                        for n, j in enumerate(preds_mask[ent_start:]):
                            if j != i:
                                ent_end = (ent_start + n - 1)
                                break
                        preds[(ent_start, ent_end)] = ent_pred_labels[i]
                        i += 1
                for gold_span, gold_label in golds.items():
                    if gold_span not in preds.keys():
                        if gold_label == 1:
                            fn += 1
                            is_correct = False

                ents_pred = [0] * len(ents)
                for i, pred in enumerate(preds):
                    # convert to torch.tensor
                    inputs = torch.empty(
                        [len(ents[i]),
                         self.bert_config.hidden_size]).to(self.device)
                    for j, token in enumerate(ents[i]):
                        inputs[j, :] = token

                    inputs = torch.mean(inputs, dim=0, keepdim=True)
                    outputs = self.mlp(inputs)

                    if pred in golds.keys():
                        target = golds[pred]
                        if target == 1:
                            if outputs < self.clf_th:
                                fn += 1
                                is_correct = False
                            else:
                                tp += 1
                        else:
                            if outputs < self.clf_th:
                                tn += 1
                            else:
                                fp += 1
                                is_correct = False
                    else:
                        if outputs < self.clf_th:
                            pass
                        else:
                            fp += 1
                            is_correct = False

                    outputs_ = outputs.to('cpu').detach().numpy().copy()
                    if np.all(outputs_ > self.clf_th):
                        ents_pred[i] = 1

                if is_correct:
                    save_dir = correct_save_dir
                else:
                    save_dir = incorrect_save_dir
                save_path = save_dir / (target_fname + ".conll")
                lines = []
                elem_cnt = -1
                for i in range(len(target_example.text)):
                    text = target_example.text[i]
                    label = target_example.label[i]
                    start = target_example.start[i]
                    end = target_example.end[i]
                    if label == "B-Element":
                        elem_cnt += 1
                        if ents_pred[elem_cnt] == 1:
                            lines.append(f"B-Main\t{start}\t{end}\t{text}")
                        elif ents_pred[elem_cnt] == 0:
                            lines.append(f"{label}\t{start}\t{end}\t{text}")
                    elif label == "I-Element":
                        if ents_pred[elem_cnt] == 1:
                            lines.append(f"I-Main\t{start}\t{end}\t{text}")
                        elif ents_pred[elem_cnt] == 0:
                            lines.append(f"{label}\t{start}\t{end}\t{text}")
                    else:
                        lines.append(f"{label}\t{start}\t{end}\t{text}")

                with save_path.open("w") as f:
                    f.write("\n".join(lines))

        return Score(tp, fp, tn, fn).calc_score()
Example #16
def test(args):
    # setup multiprocessing instance
    torch.multiprocessing.set_sharing_strategy('file_system')

    # setup data_loader instances
    if args.arch == "MLP":
        test_data_loader = EdgeDataLoader(mode="test",
                                          data_path=args.data,
                                          batch_size=1,
                                          shuffle=True,
                                          num_workers=4,
                                          batch_type="large_batch")
    elif args.arch == "DeepSetMLP":
        test_data_loader = SubGraphDataLoader(mode="test",
                                              data_path=args.data,
                                              batch_size=1,
                                              shuffle=True,
                                              num_workers=4,
                                              batch_type="large_batch")
    elif args.arch == "DeepAPGMLP":
        test_data_loader = AnchorParentDataLoader(mode="test",
                                                  data_path=args.data,
                                                  batch_size=1,
                                                  shuffle=True,
                                                  num_workers=4,
                                                  batch_type="large_batch")

    # setup device
    device = torch.device(
        f'cuda:{args.device}' if torch.cuda.is_available() else 'cpu')

    # load model
    if args.arch == "MLP":
        model = MLP(vocab_size=29654,
                    embed_dim=250,
                    first_hidden=1000,
                    second_hidden=500,
                    activation=nn.LeakyReLU())
        # model = MLP(vocab_size=431416, embed_dim=250, first_hidden=1000, second_hidden=500, activation=nn.LeakyReLU())
    elif args.arch == "DeepSetMLP":
        model = DeepSetMLP(vocab_size=29654,
                           embed_dim=250,
                           first_hidden=1500,
                           second_hidden=1000,
                           activation=nn.LeakyReLU())
        # model = DeepSetMLP(vocab_size=431416, embed_dim=250, first_hidden=1500, second_hidden=1000, activation=nn.LeakyReLU())
    elif args.arch == "DeepAPGMLP":
        model = DeepAPGMLP(vocab_size=29654,
                           embed_dim=250,
                           first_hidden=2000,
                           second_hidden=1000,
                           activation=nn.LeakyReLU())
    checkpoint = torch.load(args.resume)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()

    # get function handles of loss and metrics
    loss_fn = bce_loss
    metric_fn = [
        macro_averaged_rank, batched_topk_hit_1, batched_topk_hit_3,
        batched_topk_hit_5, batched_scaled_MRR
    ]

    # start evaluation on test data
    total_loss = 0.0
    total_metrics = torch.zeros(len(metric_fn))

    with torch.no_grad():
        for batched_examples in tqdm(test_data_loader):
            energy_scores = []
            all_labels = []
            if len(batched_examples) == 3:
                batched_parents, batched_children, batched_labels = batched_examples[
                    0], batched_examples[1], batched_examples[2]
                for parents, children, labels in zip(batched_parents,
                                                     batched_children,
                                                     batched_labels):
                    parents, children = parents.to(device), children.to(device)
                    prediction = model(parents, children).to(device)
                    loss = loss_fn(prediction, labels.to(device))
                    total_loss += loss.item()
                    energy_scores.extend(prediction.squeeze_().tolist())
                    all_labels.extend(labels.tolist())
            elif len(batched_examples) == 4:
                batched_parents, batched_siblings, batched_children, batched_labels = batched_examples[
                    0], batched_examples[1], batched_examples[
                        2], batched_examples[3]
                for parents, siblings, children, labels in zip(
                        batched_parents, batched_siblings, batched_children,
                        batched_labels):
                    parents, siblings, children = parents.to(
                        device), siblings.to(device), children.to(device)
                    prediction = model(parents, siblings, children).to(device)
                    loss = loss_fn(prediction, labels.to(device))
                    total_loss += loss.item()
                    energy_scores.extend(prediction.squeeze_().tolist())
                    all_labels.extend(labels.tolist())
            elif len(batched_examples) == 5:
                batched_parents, batched_siblings, batched_grand_parents, batched_children, batched_labels = batched_examples[
                    0], batched_examples[1], batched_examples[
                        2], batched_examples[3], batched_examples[4]
                for parents, siblings, grand_parents, children, labels in zip(
                        batched_parents, batched_siblings,
                        batched_grand_parents, batched_children,
                        batched_labels):
                    parents, siblings, grand_parents, children = parents.to(
                        device), siblings.to(device), grand_parents.to(
                            device), children.to(device)
                    prediction = model(parents, siblings, grand_parents,
                                       children).to(device)
                    loss = loss_fn(prediction, labels.to(device))
                    total_loss += loss.item()
                    energy_scores.extend(prediction.squeeze_().tolist())
                    all_labels.extend(labels.tolist())

            energy_scores = torch.tensor(energy_scores).unsqueeze_(1)
            all_labels = torch.tensor(all_labels)

            # computing metrics on test set
            for i, metric in enumerate(metric_fn):
                total_metrics[i] += metric(energy_scores, all_labels)

    n_samples = test_data_loader.n_samples
    print(f"Test loss: {total_loss / n_samples}")
    for i in range(len(metric_fn)):
        print(
            f"{metric_fn[i].__name__} : {total_metrics[i].item() / n_samples}")
Example #17
def main():
    # check cuda
    device = f'cuda:{args.gpu}' if torch.cuda.is_available() and args.gpu >= 0 else 'cpu'
    # load data
    dataset = DglNodePropPredDataset(name=args.dataset)
    evaluator = Evaluator(name=args.dataset)

    split_idx = dataset.get_idx_split()
    g, labels = dataset[0] # graph: DGLGraph object, label: torch tensor of shape (num_nodes, num_tasks)
    
    if args.dataset == 'ogbn-arxiv':
        g = dgl.to_bidirected(g, copy_ndata=True)
        
        feat = g.ndata['feat']
        feat = (feat - feat.mean(0)) / feat.std(0)
        g.ndata['feat'] = feat

    g = g.to(device)
    feats = g.ndata['feat']
    labels = labels.to(device)

    # load masks for train / validation / test
    train_idx = split_idx["train"].to(device)
    valid_idx = split_idx["valid"].to(device)
    test_idx = split_idx["test"].to(device)

    n_features = feats.size()[-1]
    n_classes = dataset.num_classes
    
    # load model
    if args.model == 'mlp':
        model = MLP(n_features, args.hid_dim, n_classes, args.num_layers, args.dropout)
    elif args.model == 'linear':
        model = MLPLinear(n_features, n_classes)
    else:
        raise NotImplementedError(f'Model {args.model} is not supported.')

    model = model.to(device)
    print(f'Model parameters: {sum(p.numel() for p in model.parameters())}')

    if args.pretrain:
        print('---------- Before ----------')
        model.load_state_dict(torch.load(f'base/{args.dataset}-{args.model}.pt'))
        model.eval()

        y_soft = model(feats).exp()

        y_pred = y_soft.argmax(dim=-1, keepdim=True)
        valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
        test_acc = evaluate(y_pred, labels, test_idx, evaluator)
        print(f'Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}')

        print('---------- Correct & Smoothing ----------')
        cs = CorrectAndSmooth(num_correction_layers=args.num_correction_layers,
                              correction_alpha=args.correction_alpha,
                              correction_adj=args.correction_adj,
                              num_smoothing_layers=args.num_smoothing_layers,
                              smoothing_alpha=args.smoothing_alpha,
                              smoothing_adj=args.smoothing_adj,
                              autoscale=args.autoscale,
                              scale=args.scale)
        
        mask_idx = torch.cat([train_idx, valid_idx])
        y_soft = cs.correct(g, y_soft, labels[mask_idx], mask_idx)
        y_soft = cs.smooth(g, y_soft, labels[mask_idx], mask_idx)
        y_pred = y_soft.argmax(dim=-1, keepdim=True)
        valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)
        test_acc = evaluate(y_pred, labels, test_idx, evaluator)
        print(f'Valid acc: {valid_acc:.4f} | Test acc: {test_acc:.4f}')
    else:
        opt = optim.Adam(model.parameters(), lr=args.lr)

        best_acc = 0
        best_model = copy.deepcopy(model)

        # training
        print('---------- Training ----------')
        for i in range(args.epochs):

            model.train()
            opt.zero_grad()

            logits = model(feats)
            
            train_loss = F.nll_loss(logits[train_idx], labels.squeeze(1)[train_idx])
            train_loss.backward()

            opt.step()
            
            model.eval()
            with torch.no_grad():
                logits = model(feats)
                
                y_pred = logits.argmax(dim=-1, keepdim=True)

                train_acc = evaluate(y_pred, labels, train_idx, evaluator)
                valid_acc = evaluate(y_pred, labels, valid_idx, evaluator)

                print(f'Epoch {i} | Train loss: {train_loss.item():.4f} | Train acc: {train_acc:.4f} | Valid acc {valid_acc:.4f}')

                if valid_acc > best_acc:
                    best_acc = valid_acc
                    best_model = copy.deepcopy(model)
        
        # testing & saving model
        print('---------- Testing ----------')
        best_model.eval()
        
        logits = best_model(feats)
        
        y_pred = logits.argmax(dim=-1, keepdim=True)
        test_acc = evaluate(y_pred, labels, test_idx, evaluator)
        print(f'Test acc: {test_acc:.4f}')

        if not os.path.exists('base'):
            os.makedirs('base')

        torch.save(best_model.state_dict(), f'base/{args.dataset}-{args.model}.pt')
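
# --- Hedged sketch (not part of the original snippet) ---
# The evaluate() helper used above is not shown; with an OGB Evaluator it would
# typically slice the split indices and report accuracy.
def evaluate(y_pred, labels, idx, evaluator):
    return evaluator.eval({
        'y_true': labels[idx],
        'y_pred': y_pred[idx],
    })['acc']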
Example #18
def main():
    num_epoches = 40
    batch_size = 256
    num_classes = 1
    embed_size = 16
    p = 0.1
    hidden_size = 64
    setup_seeds(RANDOM_STATE)

    # Data loading
    train_file = osp.join(INPUT_PATH, dataset, "train.csv")
    test_file = osp.join(INPUT_PATH, dataset, "test.csv")
    train_df = pd.read_csv(train_file)
    test_df = pd.read_csv(test_file)




    columns = train_df.columns
    stat_columns_file = osp.join(INPUT_PATH, "stat_columns.txt")
    category_columns_file = osp.join(INPUT_PATH, "category_columns.txt")
    stat_columns = pickle_load(stat_columns_file)
    category_columns = pickle_load(category_columns_file)[:-1]

    feature_columns =  stat_columns + category_columns
    normalized_columns = [stat_columns[-2]]
    except_normalized_columns = [column for column in feature_columns if column not in normalized_columns]
    print(f"category_columns: {category_columns}")
    print(f"normalized_columns: {normalized_columns}")
    print(f"except_normalized_columns: {except_normalized_columns}")

    standard_scaler = StandardScaler()
    standard_scaler.fit(train_df[normalized_columns].values)
    train_normalized = standard_scaler.transform(train_df[normalized_columns].values)
    test_normalized = standard_scaler.transform(test_df[normalized_columns].values)


    X_train = np.concatenate((train_normalized, train_df[except_normalized_columns].values), axis=1)
    y_train = train_df["target"].values
    X_test = np.concatenate((test_normalized, test_df[except_normalized_columns].values), axis=1)
    y_test = test_df["target"].values

    logger.write("x_train.shape: " + str(X_train.shape))
    logger.write("y_train.shape: " + str(y_train.shape))
    logger.write("x_test.shape: " + str(X_test.shape))


    n_features = len(feature_columns)
    n_category_features = len(category_columns)
    n_stat_features = len(stat_columns)

    embeds_desc = []
    X_total = np.concatenate((X_train, X_test), axis=0)
    for i in range(n_stat_features, n_features):
        cur_column = X_total[:, i]
        num_embed = int(max(cur_column) + 1)
        embeds_desc.append([num_embed, embed_size])

    # Train process
    logger.write("Training...")
    model = MLP(n_stat_features, hidden_size, embeds_desc, num_classes, p)
    model.to(device)
    if not osp.exists(MODEL_PATH):
        os.makedirs(MODEL_PATH)
    model_file = osp.join(MODEL_PATH, f"model_{dataset}.pt")
    patience = 6
    early_stopping = EarlyStopping(patience=patience, verbose=True, path=model_file)

    logger.write(model)

    train(X_train, y_train, model, n_stat_features, n_features, early_stopping, num_epoches, batch_size)

    # Predict process
    logger.write("Predicting...")
    stat_matrix_test = X_test[:, :n_stat_features]
    embeds_input_test = []
    for i in range(n_stat_features, n_features):
        embeds_input_test.append(X_test[:, i])
    stat_matrix_test = torch.tensor(stat_matrix_test, dtype=torch.float).to(device)
    embeds_input_test = torch.tensor(embeds_input_test, dtype=torch.long).to(device)
    test_data = [stat_matrix_test, embeds_input_test]

    model.load_state_dict(torch.load(model_file))
    y_test_prob = predict(test_data, model).cpu().numpy()

    # calc metrics

    logger.write(f"test max prob:{y_test_prob.max()}")
    logger.write(f"test min prob:{y_test_prob.min()}")


    logger.write("Metrics calculation...")
    recall_precision_score(y_test_prob, y_test)

    print("")
Example #19
class Trainer():
    def __init__(self, config_path):
        config = configparser.ConfigParser()
        config.read(config_path)

        self.n_epoch = config.getint("general", "n_epoch")
        self.batch_size = config.getint("general", "batch_size")
        self.train_bert = config.getboolean("general", "train_bert")
        self.lr = config.getfloat("general", "lr")
        self.cut_frac = config.getfloat("general", "cut_frac")
        self.log_dir = Path(config.get("general", "log_dir"))
        if not self.log_dir.exists():
            self.log_dir.mkdir(parents=True)
        self.model_save_freq = config.getint("general", "model_save_freq")

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        # bert_config_path = config.get("bert", "config_path")
        # bert_tokenizer_path = config.get("bert", "tokenizer_path")
        # bert_model_path = config.get("bert", "model_path")

        self.bert_tokenizer = LongformerTokenizer.from_pretrained(
            'allenai/longformer-base-4096')
        # self.bert_tokenizer = BertTokenizer.from_pretrained(bert_tokenizer_path)
        tkzer_save_dir = self.log_dir / "tokenizer"
        if not tkzer_save_dir.exists():
            tkzer_save_dir.mkdir()
        self.bert_tokenizer.save_pretrained(tkzer_save_dir)
        self.bert_model = LongformerModel.from_pretrained(
            'allenai/longformer-base-4096')
        self.bert_config = self.bert_model.config
        # self.bert_config = BertConfig.from_pretrained(bert_config_path)
        # self.bert_model = BertModel.from_pretrained(bert_model_path, config=self.bert_config)
        self.max_seq_length = self.bert_config.max_position_embeddings - 2
        # self.max_seq_length = self.bert_config.max_position_embeddings
        self.bert_model.to(self.device)

        if self.train_bert:
            self.bert_model.train()
        else:
            self.bert_model.eval()

        train_conll_path = config.get("data", "train_path")
        print("train path", train_conll_path)
        assert Path(train_conll_path).exists()
        dev_conll_path = config.get("data", "dev_path")
        print("dev path", dev_conll_path)
        assert Path(dev_conll_path).exists()
        dev1_conll_path = Path(dev_conll_path) / "1"
        print("dev1 path", dev1_conll_path)
        assert dev1_conll_path.exists()
        dev2_conll_path = Path(dev_conll_path) / "2"
        print("dev2 path", dev2_conll_path)
        assert dev2_conll_path.exists()
        self.train_dataset = ConllDataset(train_conll_path)
        # self.dev_dataset = ConllDataset(dev_conll_path)
        self.dev1_dataset = ConllDataset(dev1_conll_path)
        self.dev2_dataset = ConllDataset(dev2_conll_path)
        if self.batch_size == -1:
            self.batch_size = len(self.train_dataset)

        self.scaler = torch.cuda.amp.GradScaler()
        tb_cmt = f"lr_{self.lr}_cut-frac_{self.cut_frac}"
        self.writer = SummaryWriter(log_dir=self.log_dir, comment=tb_cmt)

    def transforms(self, example, label_list):
        feature = convert_single_example(example, label_list,
                                         self.max_seq_length,
                                         self.bert_tokenizer)
        label_ids = feature.label_ids
        label_map = feature.label_map
        gold_labels = [-1] * self.max_seq_length
        # Get "Element" or "Main" token indices
        for i, lid in enumerate(label_ids):
            if lid == label_map['B-Element']:
                gold_labels[i] = 0
            elif lid == label_map['B-Main']:
                gold_labels[i] = 1
            elif lid in (label_map['I-Element'], label_map['I-Main']):
                gold_labels[i] = 2
            elif lid == label_map['X']:
                gold_labels[i] = 3
        # flush data to bert model
        input_ids = torch.tensor(feature.input_ids).unsqueeze(0).to(
            self.device)
        if self.train_bert:
            model_output = self.bert_model(input_ids)
        else:
            with torch.no_grad():
                model_output = self.bert_model(input_ids)

        # lstm (ignore padding parts)
        model_fv = model_output[0]
        input_ids = torch.tensor(feature.input_ids)
        label_ids = torch.tensor(feature.label_ids)
        gold_labels = torch.tensor(gold_labels)
        return model_fv, input_ids, label_ids, gold_labels

    @staticmethod
    def extract_tokens(fv, gold_labels):
        ents, golds = [], []
        ents_mask = [-1] * len(gold_labels)
        ent, gold, ent_id = [], None, 0
        ent_flag = False
        for i, gt in enumerate(gold_labels):
            if gt == 2:  # in case of "I-xxx"
                ent.append(fv[i, :])
                ents_mask[i] = ent_id
                ent_end = i
            elif gt == 3 and ent_flag:  # in case of "X"
                ent.append(fv[i, :])
                ents_mask[i] = ent_id
                ent_end = i
            elif ent:
                ents.append(ent)
                golds.append(gold)
                ent = []
                ent_id += 1
                ent_flag = False
            if gt in (0, 1):  # in case of "B-xxx"
                ent.append(fv[i, :])
                gold = gt
                ents_mask[i] = ent_id
                ent_start = i
                ent_flag = True
        else:
            if ent:
                ents.append(ent)
                golds.append(gold)
        return ents, golds, ents_mask

    def eval(self, dataset):
        tp, fp, tn, fn = 0, 0, 0, 0
        with torch.no_grad():
            for data in tqdm(dataset):
                # flush to Bert
                fname, example = data

                try:
                    fvs, input_ids, label_ids, gold_labels = self.transforms(
                        example, dataset.label_list)
                except RuntimeError:
                    print(f"{fname} cannot put in memory!")
                    continue

                # extract Element/Main tokens
                ents, ent_golds, _ = self.extract_tokens(
                    fvs.squeeze(0), gold_labels)

                for i, ent in enumerate(ents):
                    # convert to torch.tensor
                    inputs = torch.empty(
                        [len(ent),
                         self.bert_config.hidden_size]).to(self.device)
                    for j, token in enumerate(ent):
                        inputs[j, :] = token
                    target = ent_golds[i]
                    inputs = torch.mean(inputs, dim=0, keepdim=True)

                    # classification
                    outputs = self.mlp(inputs)
                    if target == 1:
                        if outputs < 0.5:
                            fn += 1
                        else:
                            tp += 1
                    else:
                        if outputs < 0.5:
                            tn += 1
                        else:
                            fp += 1

        return Score(tp, fp, tn, fn).calc_score()

    def train(self):
        # MLP
        self.mlp = MLP(self.bert_config.hidden_size)
        self.mlp.to(self.device)
        self.mlp.train()
        # parameters to optimize
        params = list(self.mlp.parameters())
        if self.train_bert:
            params += list(self.bert_model.parameters())
        # loss
        self.criterion = BCEWithLogitsLoss()
        # optimizer
        self.optimizer = AdamW(params, lr=self.lr)
        num_train_steps = int(self.n_epoch * len(self.train_dataset) /
                              self.batch_size)
        num_warmup_steps = int(self.cut_frac * num_train_steps)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps, num_train_steps)

        try:
            best_dev1_f1, best_dev2_f1 = 0, 0
            # best_dev_f1 = 0
            itr = 1
            for epoch in range(1, self.n_epoch + 1):
                print("Epoch : {}".format(epoch))
                print("training...")
                for i in tqdm(
                        range(0, len(self.train_dataset), self.batch_size)):
                    # fvs, ents, batch_samples, inputs, outputs = None, None, None, None, None
                    itr += i
                    # create batch samples
                    if (i + self.batch_size) < len(self.train_dataset):
                        end_i = (i + self.batch_size)
                    else:
                        end_i = len(self.train_dataset)

                    batch_samples, batch_golds = [], []

                    for j in range(i, end_i):
                        # flush to Bert
                        fname, example = self.train_dataset[j]

                        fvs, input_ids, label_ids, gold_labels = self.transforms(
                            example, self.train_dataset.label_list)

                        # extract Element/Main tokens
                        ents, ent_golds, _ = self.extract_tokens(
                            fvs.squeeze(0), gold_labels)
                        for e in ents:
                            ent = torch.empty(
                                [len(e),
                                 self.bert_config.hidden_size]).to(self.device)
                            for k, t in enumerate(e):
                                ent[k, :] = t
                            batch_samples.append(torch.mean(ent, dim=0))
                        batch_golds.extend(ent_golds)

                    # convert to torch.tensor
                    inputs = torch.empty(
                        [len(batch_samples),
                         self.bert_config.hidden_size]).to(self.device)
                    for j, t in enumerate(batch_samples):
                        inputs[j, :] = t
                    targets = torch.tensor(batch_golds,
                                           dtype=torch.float).unsqueeze(1)

                    self.optimizer.zero_grad()
                    with torch.cuda.amp.autocast():
                        outputs = self.mlp(inputs)
                        loss = self.criterion(outputs, targets.to(self.device))
                    self.scaler.scale(loss).backward()
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.scheduler.step()

                    del fvs, ents, batch_samples, inputs, outputs
                    torch.cuda.empty_cache()

                    # write to SummaryWriter
                    self.writer.add_scalar("loss", loss.item(), itr)
                    self.writer.add_scalar(
                        "lr", self.optimizer.param_groups[0]["lr"], itr)

                # switch to eval mode for the epoch-level evaluation
                if self.train_bert:
                    self.bert_model.eval()
                self.mlp.eval()

                print("train data evaluation...")
                tr_acc, tr_rec, _, tr_prec, tr_f1 = self.eval(
                    self.train_dataset)
                print(
                    f"acc: {tr_acc}, rec: {tr_rec}, prec: {tr_prec}, f1: {tr_f1}"
                )
                self.writer.add_scalar("train/acc", tr_acc, epoch)
                self.writer.add_scalar("train/rec", tr_rec, epoch)
                self.writer.add_scalar("train/prec", tr_prec, epoch)
                self.writer.add_scalar("train/f1", tr_f1, epoch)
                # print("dev data evaluation...")
                # dev_acc, dev_rec, _, dev_prec, dev_f1 = self.eval(self.dev_dataset)
                # print(f"acc: {dev_acc}, rec: {dev_rec}, prec: {dev_prec}, f1: {dev_f1}")
                # self.writer.add_scalar("dev/acc", dev_acc, epoch)
                # self.writer.add_scalar("dev/rec", dev_rec, epoch)
                # self.writer.add_scalar("dev/prec", dev_prec, epoch)
                # self.writer.add_scalar("dev/f1", dev_f1, epoch)
                # self.writer.flush()
                print("dev1 data evaluation...")
                dev1_acc, dev1_rec, _, dev1_prec, dev1_f1 = self.eval(
                    self.dev1_dataset)
                print(
                    f"acc: {dev1_acc}, rec: {dev1_rec}, prec: {dev1_prec}, f1: {dev1_f1}"
                )
                self.writer.add_scalar("dev1/acc", dev1_acc, epoch)
                self.writer.add_scalar("dev1/rec", dev1_rec, epoch)
                self.writer.add_scalar("dev1/prec", dev1_prec, epoch)
                self.writer.add_scalar("dev1/f1", dev1_f1, epoch)
                self.writer.flush()
                print("dev2 data evaluation...")
                dev2_acc, dev2_rec, _, dev2_prec, dev2_f1 = self.eval(
                    self.dev2_dataset)
                print(
                    f"acc: {dev2_acc}, rec: {dev2_rec}, prec: {dev2_prec}, f1: {dev2_f1}"
                )
                self.writer.add_scalar("dev2/acc", dev2_acc, epoch)
                self.writer.add_scalar("dev2/rec", dev2_rec, epoch)
                self.writer.add_scalar("dev2/prec", dev2_prec, epoch)
                self.writer.add_scalar("dev2/f1", dev2_f1, epoch)
                self.writer.flush()
                if self.train_bert:
                    self.bert_model.train()
                self.mlp.train()

                if epoch % self.model_save_freq == 0:
                    curr_log_dir = self.log_dir / f"epoch_{epoch}"
                    if not curr_log_dir.exists():
                        curr_log_dir.mkdir()
                    if self.train_bert:
                        self.bert_model.save_pretrained(curr_log_dir)
                    torch.save(self.mlp.state_dict(),
                               curr_log_dir / "mlp.model")

                if best_dev1_f1 <= dev1_f1:
                    best_dev1_f1 = dev1_f1
                    best_dev1_epoch = epoch
                    if self.train_bert:
                        best_dev1_model = copy.deepcopy(self.bert_model).cpu()
                    best_dev1_mlp = copy.deepcopy(self.mlp).cpu().state_dict()
                if best_dev2_f1 <= dev2_f1:
                    best_dev2_f1 = dev2_f1
                    best_dev2_epoch = epoch
                    if self.train_bert:
                        best_dev2_model = copy.deepcopy(self.bert_model).cpu()
                    best_dev2_mlp = copy.deepcopy(self.mlp).cpu().state_dict()

        except KeyboardInterrupt:
            print(
                f"Best epoch was dev1: #{best_dev1_epoch}, dev2: #{best_dev2_epoch}!\nSave params..."
            )
            save_dev1_dir = Path(self.log_dir) / "dev1_best"
            if not save_dev1_dir.exists():
                save_dev1_dir.mkdir()
            save_dev2_dir = Path(self.log_dir) / "dev2_best"
            if not save_dev2_dir.exists():
                save_dev2_dir.mkdir()
            if self.train_bert:
                best_dev1_model.save_pretrained(save_dev1_dir)
                best_dev2_model.save_pretrained(save_dev2_dir)
            torch.save(best_dev1_mlp, save_dev1_dir / "mlp.model")
            torch.save(best_dev2_mlp, save_dev2_dir / "mlp.model")
            print("Training was successfully finished!")
            raise KeyboardInterrupt
        else:
            # print(f"Best epoch was #{best_dev_epoch}!\nSave params...")
            # save_dev_dir = Path(self.log_dir) / "best"
            # if not save_dev_dir.exists():
            #     save_dev_dir.mkdir()
            # if self.train_bert:
            #     best_dev_model.save_pretrained(save_dev_dir)
            # torch.save(best_dev_mlp, save_dev_dir / "mlp.model")
            # print("Training was successfully finished!")
            print(
                f"Best epoch was dev1: #{best_dev1_epoch}, dev2: #{best_dev2_epoch}!\nSave params..."
            )
            save_dev1_dir = Path(self.log_dir) / "dev1_best"
            if not save_dev1_dir.exists():
                save_dev1_dir.mkdir()
            save_dev2_dir = Path(self.log_dir) / "dev2_best"
            if not save_dev2_dir.exists():
                save_dev2_dir.mkdir()
            if self.train_bert:
                best_dev1_model.save_pretrained(save_dev1_dir)
                best_dev2_model.save_pretrained(save_dev2_dir)
            torch.save(best_dev1_mlp, save_dev1_dir / "mlp.model")
            torch.save(best_dev2_mlp, save_dev2_dir / "mlp.model")
            print("Training was successfully finished!")
            sys.exit()
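
# A minimal sketch of the kind of binary-classification head the training loop above
# assumes: it takes the BERT hidden size and emits a single raw logit per averaged
# entity embedding, to be consumed by BCEWithLogitsLoss. The class name MLPSketch and
# the dropout value are illustrative assumptions; the repository's real MLP class is
# defined elsewhere.
import torch
import torch.nn as nn


class MLPSketch(nn.Module):
    def __init__(self, hidden_size: int, dropout: float = 0.1):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, 1),  # one logit per entity
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)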
Example #20
def train(args):

    # load perturbation annotations and the preprocessed AnnData object
    # (paths are specific to the original author's environment)
    perturb_mock, sgRNA_list_mock = makedata.json_to_perturb_data(
        path="/home/member/xywang/WORKSPACE/MaryGUO/one-shot/MOCK_MON_crispr_combine/crispr_analysis")

    total = sc.read_h5ad("/home/member/xywang/WORKSPACE/MaryGUO/one-shot/mock_one_perturbed.h5ad")
    trainset, testset = preprocessing.make_total_data(total, sgRNA_list_mock)

    # episodic data loader: each task draws num_ways classes with support_shots
    # support examples and 15 query examples per class
    TrainSet = perturbdataloader(trainset, ways=args.num_ways,
                                 support_shots=args.num_shots, query_shots=15)
    TrainLoader = DataLoader(TrainSet, batch_size=args.batch_size_train,
                             shuffle=False, num_workers=args.num_workers)

    model = MLP(out_features=args.num_ways)

    model.to(device=args.device)
    model.train()
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(TrainLoader, total=args.num_batches) as pbar:
        for batch_idx, (inputs_support, inputs_query, target_support, target_query) in enumerate(pbar):
            model.zero_grad()

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            inputs_query = inputs_query.to(device=args.device)
            target_query = target_query.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
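            # MAML structure: for each task, adapt a copy of the parameters with one
            # gradient step on the support set (inner loop), evaluate the adapted
            # parameters on the query set, and accumulate the query losses into
            # outer_loss for the meta-update below.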
            for task_idx, (train_input, train_target, test_input,
                           test_target) in enumerate(zip(inputs_support, target_support,inputs_query, target_query)):

                train_logit = model(train_input)
                inner_loss = F.cross_entropy(train_logit, train_target)

                model.zero_grad()
                params = gradient_update_parameters(model,
                                                    inner_loss,
                                                    step_size=args.step_size,
                                                    first_order=args.first_order)

                test_logit = model(test_input, params=params)
                outer_loss += F.cross_entropy(test_logit, test_target)

                with torch.no_grad():
                    accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size_train)
            accuracy.div_(args.batch_size_train)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            if batch_idx >= args.num_batches or accuracy.item() > 0.95:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
                                                    '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)

    # start test
    test_support, test_query, test_target_support, test_target_query = \
        helpfuntions.sample_once(testset, support_shot=args.num_shots,
                                 shuffle=False, plus=len(trainset))
    test_query = torch.from_numpy(test_query).to(device=args.device)
    test_target_query = torch.from_numpy(test_target_query).to(device=args.device)

    # wrap the held-out support set in a loader for the fine-tuning phase
    # (the TrainSet/TrainLoader names are reused here)
    TrainSet = perturbdataloader_test(test_support, test_target_support)
    TrainLoader = DataLoader(TrainSet, args.batch_size_test)

    meta_optimizer.zero_grad()
    inner_losses = []
    accuracy_test = []
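    # Adaptation/test phase: fine-tune the meta-trained model on the support examples
    # of the held-out classes and, after every update, measure accuracy on their
    # query examples.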

    for epoch in range(args.num_epoch):
        model.to(device=args.device)
        model.train()

        for _, (inputs_support,target_support) in enumerate(TrainLoader):

            inputs_support = inputs_support.to(device=args.device)
            target_support = target_support.to(device=args.device)

            train_logit = model(inputs_support)
            loss = F.cross_entropy(train_logit, target_support)
            inner_losses.append(loss.item())  # keep the scalar, not the graph
            loss.backward()
            meta_optimizer.step()
            meta_optimizer.zero_grad()

            # measure query-set accuracy after every update step
            with torch.no_grad():
                test_logit = model(test_query)
                accuracy = get_accuracy(test_logit, test_target_query)
                accuracy_test.append(accuracy.item())



        if (epoch + 1) % 3 == 0:
            print('Epoch [{}/{}], Loss: {:.4f}, accuracy: {:.4f}'.format(
                epoch + 1, args.num_epoch, loss.item(), accuracy))
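

# A minimal sketch of how train(args) above could be driven from the command line.
# The attribute names mirror the ones the function reads; every default value here
# is an illustrative assumption, not a setting taken from the original project.
if __name__ == '__main__':
    import argparse
    import torch

    parser = argparse.ArgumentParser(description='Few-shot MAML training (sketch)')
    parser.add_argument('--num_ways', type=int, default=5)
    parser.add_argument('--num_shots', type=int, default=1)
    parser.add_argument('--batch_size_train', type=int, default=16)
    parser.add_argument('--batch_size_test', type=int, default=16)
    parser.add_argument('--num_workers', type=int, default=2)
    parser.add_argument('--num_batches', type=int, default=100)
    parser.add_argument('--num_epoch', type=int, default=30)
    parser.add_argument('--step_size', type=float, default=0.4)
    parser.add_argument('--first_order', action='store_true')
    parser.add_argument('--output_folder', type=str, default=None)
    args = parser.parse_args()
    args.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train(args)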