Beispiel #1
0
def _count_correct(net, dataset, batch_size):
    """Make one pass over ``dataset`` and tally how many predictions are right.

    The dataset is rewound with ``reset_epoch()`` and consumed with
    ``get_batch(batch_size)`` until its ``epoch_finish`` flag is set.  The
    pass runs under ``torch.no_grad()`` because nothing is back-propagated
    during evaluation, so no autograd graph needs to be kept.

    Args:
        net: model called as ``net(seq, seq_pos, standard_emb)``; the caller
            is responsible for putting it in eval mode.
        dataset: batch provider exposing ``reset_epoch``, ``epoch_finish``
            and ``get_batch``.
        batch_size: batch size requested from ``get_batch``.

    Returns:
        ``(correct, seen)``: number of matching predictions and number of
        labels actually returned by the dataset.
    """
    dataset.reset_epoch()
    correct = 0
    seen = 0
    with torch.no_grad():
        while not dataset.epoch_finish:
            seq, label, seq_length, mask, seq_pos, standard_emb = dataset.get_batch(batch_size)
            results = net(seq, seq_pos, standard_emb)
            _, idx = results.max(1)
            correct += int((idx == label).sum().item())
            # Count the labels really delivered instead of assuming that
            # every batch is full; the final batch may be shorter.
            seen += len(label)
    return correct, seen


def train():
    """Train and validate a ``TransformerNet`` classifier over five folds.

    For each fold a fresh network and Adam optimizer are created, the fold's
    ``trainsplit_<fold>.csv`` / ``val_<fold>.csv`` files under ``DATA_path``
    are tokenized, and the network is trained for ``Epoch`` epochs.  Every
    ``Val_every`` epochs the accuracy on both splits is printed and the
    weights are saved under ``SAVE_DIR``; every ``LR_decay_epoch`` epochs the
    learning rate is decayed.  The last validation accuracy of each fold is
    averaged and printed at the end.

    All hyper-parameters (``Epoch``, ``Batch_size``, ``Learning_rate``, ...)
    are read from module-level names defined outside this function.

    NOTE(review): if ``Epoch < Val_every`` no validation ever runs and the
    per-fold ``val_accuracy`` read below is undefined -- confirm the
    configuration guarantees at least one validation per fold.
    """
    pretrain, code_id, word_id = source_prepare()
    print('%d different words, %d different codes' % (len(word_id), len(code_id)), flush=True)

    criterion = nn.CrossEntropyLoss()

    # Single source of truth for the fold count (used by the loop and the mean).
    num_folds = 5
    Acc = 0.
    for fold in range(num_folds):
        Net = TransformerNet(Pretrain_type, pretrain, Max_seq_len, Embedding_size, Inner_hid_size, len(code_id), D_k,
                             D_v, dropout_ratio=Dropout, num_layers=Num_layers, num_head=Num_head, Freeze=Freeze_emb).cuda()
        optimizer = optim.Adam(Net.parameters(), lr=Learning_rate, eps=1e-08, weight_decay=Weight_decay)

        # Alternative dataset layout kept for reference:
        # train_file = DATA_path + 'AskAPatient/AskAPatient.fold-' + str(fold) + '.train.txt'
        # val_file = DATA_path + 'AskAPatient/AskAPatient.fold-' + str(fold) + '.validation.txt'
        train_file = DATA_path + 'trainsplit_' + str(fold) + '.csv'
        val_file = DATA_path + 'val_' + str(fold) + '.csv'
        train_data = tokenizer(word_id, code_id, train_file, pretrain_type=Pretrain_type)
        val_data = tokenizer(word_id, code_id, val_file, pretrain_type=Pretrain_type)

        print('Fold %d: %d training data, %d validation data' % (fold, len(train_data.data), len(val_data.data)))
        print('max length: %d %d' % (train_data.max_length, val_data.max_length), flush=True)

        for e in range(Epoch):
            train_data.reset_epoch()
            Net.train()
            while not train_data.epoch_finish:
                optimizer.zero_grad()
                seq, label, seq_length, mask, seq_pos, standard_emb = train_data.get_batch(Batch_size)
                results = Net(seq, seq_pos, standard_emb)
                loss = criterion(results, label)
                loss.backward()
                optimizer.step()

            if (e + 1) % Val_every == 0:
                Net.eval()

                train_correct, train_seen = _count_correct(Net, train_data, Batch_size)
                train_accuracy = float(train_correct) / float(train_seen) if train_seen else 0.0

                val_correct, _ = _count_correct(Net, val_data, Batch_size)
                val_accuracy = float(val_correct) / float(len(val_data.data))

                # ``loss`` is the loss of the last training batch of this epoch.
                print('[fold %d epoch %d] training loss: %.4f, % d correct, %.4f accuracy;'
                      ' validation: %d correct, %.4f accuracy' %
                      (fold, e, loss.item(), train_correct, train_accuracy, val_correct, val_accuracy), flush=True)

                torch.save(Net.state_dict(), SAVE_DIR + 'data1_false_biobert_' + str(fold) + '_' + str(e))

            if (e + 1) % LR_decay_epoch == 0:
                adjust_learning_rate(optimizer, LR_decay)
                print('learning rate decay!', flush=True)

        # Accumulate the most recent validation accuracy of this fold.
        Acc += val_accuracy

        # Drop the tokenized splits before the next fold builds new ones.
        del train_data, val_data
        gc.collect()

    print('finial validation accuracy: %.4f' % (Acc / num_folds))
def train(args):
    """Train a ``TransformerNet`` towards one style image and save its weights.

    The loss is the sum of a content term (MSE between the ``relu3_3``
    features of the input and of the network output) and a style term (MSE
    between Gram matrices of the output features and of the style image
    features), both computed with a ``Vgg16`` feature extractor.

    Args:
        args: namespace providing ``seed``, ``image_size``, ``dataset``,
            ``batch_size``, ``style_image``, ``style_size``, ``lr``,
            ``epochs``, ``content_weight``, ``style_weight`` and
            ``save_model_dir``.
    """
    device = torch.device("cpu")

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    # load dataset
    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    # load style image
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # One copy of the style image per batch element so Gram matrices line up.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    # load network
    transformer = TransformerNet().to(device)
    vgg = Vgg16(requires_grad=False).to(device)

    # define optimizer and loss function
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    for e in range(args.epochs):
        transformer.train()
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            count += len(x)
            optimizer.zero_grad()

            image_original = x.to(device)
            # Feed the copy that lives on ``device``; passing the raw loader
            # tensor only works while ``device`` happens to be the CPU.
            image_transformed = transformer(image_original)

            image_original = utils.normalize_batch(image_original)
            image_transformed = utils.normalize_batch(image_transformed)

            # extract features to compute the content loss
            features_original = vgg(image_original)
            features_transformed = vgg(image_transformed)
            content_loss = args.content_weight * mse_loss(features_transformed.relu3_3, features_original.relu3_3)

            # extract features to compute the style loss
            style_loss = 0.
            for ft_y, gm_s in zip(features_transformed, gram_style):
                gm_y = utils.gram_matrix(ft_y)
                # The last batch may be short, so trim the style Gram
                # matrices to the current batch length.
                style_loss += mse_loss(gm_y, gm_s[:len(x), :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()

            if (batch_id + 1) % 200 == 0:
                print("Epoch {}:[{}/{}]".format(e + 1, count, len(train_dataset)))

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)
Beispiel #3
0
def train(content_img_name=None, style_img_name=None, features=None):
    """Fit a ``TransformerNet`` to one content image and one style image.

    The network is optimised for ``iteration_total`` steps on a single
    content image.  The loss is a weighted sum of a content term (MSE on the
    second feature map returned by ``features``) and a style term (MSE
    between Gram matrices of all returned feature maps).  Progress lines are
    printed and appended to ``./logs/log_exp_<exp_num>.txt``; the final
    weights are saved to ``./models/model_exp_<exp_num>.pt``.

    ``use_cuda``, ``exp_num`` and ``iteration_total`` are module-level names
    defined outside this function.

    Args:
        content_img_name: path handed to ``load_image`` for the content image.
        style_img_name: path handed to ``load_image`` for the style image.
        features: callable feature extractor returning an indexable sequence
            of feature maps.  It is required despite the ``None`` default.
    """
    transformer = TransformerNet()
    # features = Vgg16()

    lr = 0.001
    weight_content = 1e5
    weight_style = 1e10
    optimizer = torch.optim.Adam(transformer.parameters(), lr)
    mse_loss = torch.nn.MSELoss()

    style = load_image(style_img_name)
    style = style_transform(style)
    style = style.unsqueeze(0)
    style_v = Variable(style)
    style_v = normalize_batch(style_v)
    features_style = features(style_v)
    gram_style = [gram_matrix(y) for y in features_style]

    transformer.train()
    x = load_image(content_img_name)
    x = content_transform(x)
    x = x.unsqueeze(0)
    x = Variable(x)

    if use_cuda:
        transformer.cuda()
        features.cuda()
        x = x.cuda()
        gram_style = [gram.cuda() for gram in gram_style]

    # training
    count = 0
    log_name = './logs/log_exp_{}.txt'.format(exp_num)
    log = []
    while count < iteration_total:
        optimizer.zero_grad()

        y = transformer(x)

        # Normalise into fresh names.  Rebinding ``x`` here would normalise
        # the content image again on every iteration, so from the second
        # step on the network would no longer see the original image.
        # NOTE(review): assumes normalize_batch returns a new tensor rather
        # than modifying its argument in place -- confirm.
        y_norm = normalize_batch(y)
        x_norm = normalize_batch(x)

        features_y = features(y_norm)
        features_x = features(x_norm)

        loss_content = mse_loss(features_y[1], features_x[1])

        loss_style = 0.
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = gram_matrix(ft_y)
            loss_style = loss_style + mse_loss(gm_y, gm_s)

        total_loss = weight_content * loss_content + weight_style * loss_style
        total_loss.backward()
        optimizer.step()

        # log show
        count += 1
        msg = '{}\titeration: {}\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}\n'.format(
            time.ctime(), count, loss_content.item(), loss_style.item(),
            total_loss.item())
        log.append(msg)
        # Flush every 50 iterations and also on the final one, so the tail
        # of the log is not lost when iteration_total is not a multiple of 50.
        if count % 50 == 0 or count >= iteration_total:
            chunk = ''.join(log)
            print(chunk)
            with open(log_name, 'a') as f:
                f.write(chunk)
            log.clear()

    # save model
    transformer.eval()
    if use_cuda:
        transformer.cpu()
    save_model_name = './models/model_exp_{}.pt'.format(exp_num)
    torch.save(transformer.state_dict(), save_model_name)
Beispiel #4
0
def train(args):
    """Train a ``TransformerNet`` for style transfer with a selectable loss backbone.

    The content loss is the MSE between backbone features of the input and
    of the network output on the content layers; the style loss is the MSE
    between Gram matrices of the output features and of the style image
    features on the style layers.  Running averages of both losses are
    printed every ``args.log_interval`` batches, optional checkpoints are
    written every ``args.checkpoint_interval`` batches, and the final
    weights are saved under ``args.save_model_dir``.

    Args:
        args: namespace providing ``cuda``, ``backbone`` (``"vgg"`` or
            ``"resnet"``), ``seed``, ``image_size``, ``dataset``,
            ``batch_size``, ``lr``, ``style_image``, ``style_size``,
            ``epochs``, ``content_weight``, ``style_weight``,
            ``log_interval``, ``checkpoint_model_dir``,
            ``checkpoint_interval`` and ``save_model_dir``.

    Raises:
        ValueError: if ``args.backbone`` is neither ``"vgg"`` nor ``"resnet"``.
    """
    # Validate the backbone before doing any work; otherwise an unknown value
    # only surfaces later as an unrelated unbound-name error.
    if args.backbone == "vgg":
        content_layer = ['relu_4']
        style_layer = ['relu_2', 'relu_4', 'relu_7', 'relu_10']
    elif args.backbone == "resnet":
        content_layer = ["conv_3"]
        style_layer = ["conv_1", "conv_2", "conv_3", "conv_4"]
    else:
        raise ValueError(
            "unsupported backbone {!r}; expected 'vgg' or 'resnet'".format(
                args.backbone))

    device = torch.device("cuda" if args.cuda else "cpu")

    # Union of both layer lists, de-duplicated while keeping order.
    total_layer = list(dict.fromkeys(content_layer + style_layer))

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.Resize(args.image_size),
        transforms.CenterCrop(args.image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    train_dataset = datasets.ImageFolder(args.dataset, transform)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size)

    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    # The backbone is built after the transformer, as before, so the order
    # in which the two models are constructed is unchanged.
    if args.backbone == "vgg":
        loss_model = vgg16().eval().to(device)
    else:
        loss_model = resnet().eval().to(device)

    style_transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Lambda(lambda x: x.mul(255))])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    # One copy of the style image per batch element so Gram matrices line up.
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)
    # The style targets are constants: compute them without an autograd
    # graph so no backward pass ever reaches into them.
    with torch.no_grad():
        feature_style = loss_model(utils.normalize_batch(style), style_layer)
        gram_style = {
            key: utils.gram_matrix(val)
            for key, val in feature_style.items()
        }

    for e in range(args.epochs):
        transformer.train()
        agg_content_loss = 0.
        agg_style_loss = 0.
        count = 0
        for batch_id, (x, _) in enumerate(train_loader):
            n_batch = len(x)
            count += n_batch
            optimizer.zero_grad()

            x = x.to(device)
            y = transformer(x)

            y = utils.normalize_batch(y)
            x = utils.normalize_batch(x)

            feature_y = loss_model(y, total_layer)
            feature_x = loss_model(x, content_layer)

            content_loss = 0.
            for layer in content_layer:
                content_loss += args.content_weight * mse_loss(
                    feature_y[layer], feature_x[layer])

            style_loss = 0.
            for name in style_layer:
                gm_y = utils.gram_matrix(feature_y[name])
                # The last batch may be short, so trim the style Gram
                # matrices to the current batch length.
                style_loss += mse_loss(gm_y, gram_style[name][:n_batch, :, :])
            style_loss *= args.style_weight

            total_loss = content_loss + style_loss
            total_loss.backward()
            optimizer.step()
            agg_content_loss += content_loss.item()
            agg_style_loss += style_loss.item()

            if (batch_id + 1) % args.log_interval == 0:
                mesg = "{}\tEpoch {}:\t[{}/{}]\tcontent: {:.6f}\tstyle: {:.6f}\ttotal: {:.6f}".format(
                    time.ctime(), e + 1, count, len(train_dataset),
                    agg_content_loss / (batch_id + 1),
                    agg_style_loss / (batch_id + 1),
                    (agg_content_loss + agg_style_loss) / (batch_id + 1))
                print(mesg)

            if args.checkpoint_model_dir is not None and (
                    batch_id + 1) % args.checkpoint_interval == 0:
                # Checkpoint from CPU in eval mode, then restore the
                # training device and mode.
                transformer.eval().cpu()
                ckpt_model_filename = "ckpt_epoch_" + str(
                    e) + "_batch_id_" + str(batch_id + 1) + ".pth"
                ckpt_model_path = os.path.join(args.checkpoint_model_dir,
                                               ckpt_model_filename)
                torch.save(transformer.state_dict(), ckpt_model_path)
                transformer.to(device).train()

    # save model
    transformer.eval().cpu()
    save_model_filename = "epoch_" + str(args.epochs) + "_" + str(
        time.ctime()).replace(' ', '_') + "_" + str(
            args.content_weight) + "_" + str(args.style_weight) + ".model"
    save_model_path = os.path.join(args.save_model_dir, save_model_filename)
    torch.save(transformer.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)