Exemplo n.º 1
0
def train_classifier(device, args):
    """Fine-tune a Classifier built on top of a frozen, pre-trained encoder.

    Loads encoder weights from args.encoder_path, gathers labelled audio
    chunks from the training directory tree, splits them train/eval, and
    runs the epoch loop, checkpointing to the wandb run directory.
    """
    # restore the encoder and freeze it in eval mode
    encoder = FeatureExtractor()
    encoder.load_state_dict(torch.load(args.encoder_path))
    encoder.eval()

    classifier = Classifier(encoder)
    classifier.to(device)

    # every chunk path paired with the label derived from its directory name
    chunk_paths = []
    chunk_labels = []
    for label_dir in filesystem.listdir_complete(filesystem.train_audio_chunks_dir):
        entries = filesystem.listdir_complete(label_dir)
        chunk_paths.extend(entries)
        chunk_labels.extend([label_dir.split('/')[-1]] * len(entries))
    train_chunks, eval_chunks, train_labels, eval_labels = train_test_split(
        chunk_paths, chunk_labels, test_size=args.eval_size)

    # transforms and dataset
    trf = normalize
    labels_encoder = LabelsEncoder(pd.read_csv(filesystem.labels_encoding_file))
    train_dataset = DiscriminativeDataset(train_chunks, train_labels,
                                          labels_encoder, transforms=trf)
    eval_dataset = DiscriminativeDataset(eval_chunks, eval_labels,
                                         labels_encoder, transforms=trf)
    train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size,
                                  shuffle=True, num_workers=4,
                                  collate_fn=None, pin_memory=True)
    eval_dataloader = DataLoader(eval_dataset, batch_size=1,
                                 shuffle=True, num_workers=4,
                                 collate_fn=None, pin_memory=True)

    optimizer = optim.Adam(classifier.parameters(), lr=args.lr)
    loss_criterion = nn.CrossEntropyLoss()

    # step counters carried across epochs (used by the step helpers for logging)
    train_count = 0
    eval_count = 0
    for epoch in range(args.n_epochs):
        print('Epoch:', epoch, '/', args.n_epochs)
        train_count = train_step_classification(classifier, train_dataloader,
                                                optimizer, loss_criterion,
                                                args.verbose_epochs, device,
                                                train_count)
        # checkpoint after every epoch so wandb always has the latest weights
        torch.save(classifier.state_dict(),
                   os.path.join(wandb.run.dir, 'model_checkpoint.pt'))
        eval_count = eval_step_classification(classifier, eval_dataloader,
                                              loss_criterion,
                                              args.verbose_epochs, device,
                                              eval_count)
def main():
    """Run the spatial-attention model on a few gaze sequences and
    visualize each frame with its gaze point drawn on top."""
    # experiment configuration
    num_class = 6
    batch_size = 1
    time_step = 32
    cnn_feat_size = 256     # AlexNet
    gaze_size = 3
    gaze_lstm_hidden_size = 64
    gaze_lstm_projected_size = 128
    # dataset_path = '../data/gaze_dataset'
    dataset_path = '../../gaze-net/gaze_dataset'
    img_size = (224, 224)
    time_skip = 2

    # frozen AlexNet backbone used as the frame feature extractor
    arch = 'alexnet'
    extractor_model = FeatureExtractor(arch=arch)
    extractor_model.features = torch.nn.DataParallel(extractor_model.features)
    extractor_model.cuda()      # NOTE(review): remove this call when running on CPU
    extractor_model.eval()

    model = SpatialAttentionModel(num_class, cnn_feat_size, gaze_size,
                                  gaze_lstm_hidden_size,
                                  gaze_lstm_projected_size)
    model.cuda()

    # restore attention-model weights from the saved checkpoint
    model = load_checkpoint(model)

    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(
        dataset_path, subset='training', crop=False,
        batch_size=batch_size, target_size=img_size,
        class_mode='sequence_pytorch', time_skip=time_skip)
    # small dataset, error using validation split
    val_data = trainGenerator.flow_from_directory(
        dataset_path, subset='validation', crop=False,
        batch_size=batch_size, target_size=img_size,
        class_mode='sequence_pytorch', time_skip=time_skip)

    # predict on a handful of sequences and show them frame by frame
    for _ in range(10):
        print("start a new interaction")
        # img_seq: (ts,224,224,3), gaze_seq: (ts, 3), output: (ts, 6)
        # [img_seq, gaze_seq], target = next(val_data)
        [img_seq, gaze_seq], target = next(train_data)
        restart = True

        predict(img_seq, gaze_seq, extractor_model, model, restart=restart)
        print(target)
        for t in range(img_seq.shape[0]):
            # predict(img_seq[t], gaze_seq[t], None, model, restart=restart)
            frame = img_seq[t, :, :, :]
            # gaze columns 1 and 2 are treated here as the (x, y) fixation
            cv2.circle(frame, (int(gaze_seq[t, 1]), int(gaze_seq[t, 2])),
                       10, (255, 0, 0), -1)
            cv2.imshow('ImageWindow', frame)
            cv2.waitKey(33)
Exemplo n.º 3
0
def main():
    """Load a trained FeatureExtractor, run it on a waveform and save the output.

    Paths (model / input / output) come from the CLI parser; the reconstruction
    MSE against the input is printed to stdout. Currently feeds random noise
    instead of the real input file (see the commented line).
    """
    parser = get_parser()
    args = parser.parse_args()
    model_path = args.model
    input_path = args.input
    sound_path = args.output

    model = FeatureExtractor()
    model.load_state_dict(torch.load(model_path))
    device = torch.device('cuda')
    cpu_device = torch.device('cpu')
    model.to(device)
    #data = normalize(torchaudio.load(input_path)[0][0].reshape(1, -1))
    data = torch.from_numpy(normalize(torch.randn(1,
                                                  132480))).float().to(device)
    data = data.reshape(1, 1, -1)
    model.eval()
    # Bug fix: inference previously ran with autograd enabled, so `sound`
    # carried a grad history — wasted memory and an output tensor that
    # cannot be serialized as plain data. no_grad() disables graph building.
    with torch.no_grad():
        sound = model(data)
        print(functional.mse_loss(sound, data).item())
    sound = sound.to(cpu_device)
    torchaudio.save(sound_path, sound.reshape(-1), 44100)
Exemplo n.º 4
0
# Restore the three components of the domain-adaptation model (feature
# extractor / label predictor / domain classifier) from their checkpoints.
feature_extractor.load_state_dict(
    torch.load(r'saved_model\feature_extractor_CE_8_subjects.pkl'))
domain_classifier.load_state_dict(
    torch.load(r'saved_model\domain_classifier_CE_8_subjects.pkl'))
label_predictor.load_state_dict(
    torch.load(r'saved_model\label_predictor_CE_8_subjects.pkl'))

# ------------------------------------------ Testing Stage -------------------------------------------------------- #

# sliding-window segmentation parameters for the test signal
window_size = 52
stride = 1
max_fit = 30
jump = 1
threshold = 60 / 128  # decision threshold — presumably a ratio on a 0-128 scale; TODO confirm

# inference mode: freeze batch-norm / dropout in all three sub-networks
feature_extractor.eval()
label_predictor.eval()
domain_classifier.eval()
print('Start Testing...')
# iterate over the 11 subjects
for subject_index in range(len(total_user_data)):
    # subject_index = 6
    start_time = time.time()
    user_data = total_user_data[subject_index]
    user_labels = total_user_labels[subject_index]

    # per-subject accumulators for timings and predictions
    # (the remainder of this loop body lies outside this excerpt)
    iter_times = []
    iter_activity_times = []
    label_prediction = []
    label_truth = []
Exemplo n.º 5
0
def train(model, train_loader, valid_loader, optimizer, criterion, args):
    """Train the model with an MSE loss plus a feature-space content loss.

    Resumes from `fsrcnn_{scale}x.pt` when args.load is set, saves the best
    checkpoint (lowest *training* error) back to the same path, and logs
    epoch losses through the writer.

    Args:
        model: network mapping input stroke tensors to corrected ones.
        train_loader: DataLoader yielding (inputs, target, stroke_num).
        valid_loader: DataLoader yielding (inputs, target).
        optimizer: torch optimizer over model parameters.
        criterion: loss callable used for both MSE and content terms.
        args: namespace with scale, load, gpu_id, log_path, model_name,
            epochs, check_interval, stroke_length, batch_size fields.
    """
    # declare content loss
    best_err = None
    feature_extractor = FeatureExtractor().cuda()
    feature_extractor.eval()

    # load data
    model_path = f'fsrcnn_{args.scale}x.pt'
    checkpoint = {'epoch': 1}   # start from 1

    # load model from exist .pt file
    if args.load is True and os.path.isfile(model_path):
        r"""
        load a pickle file from exist parameter

        state_dict: model's state dict
        epoch: parameters were updated in which epoch
        """
        checkpoint = torch.load(model_path, map_location=f'cuda:{args.gpu_id}')
        checkpoint['epoch'] += 1    # start from next epoch
        model.load_state_dict(checkpoint['state_dict'])

    # store the training time
    writer = writer_builder(args.log_path,args.model_name)

    for epoch in range(checkpoint['epoch'], args.epochs+1):
        model.train()
        err = 0.0
        valid_err = 0.0

        store_data_cnt = 0  # to create new dataset

        for data in tqdm(train_loader, desc=f'train epoch: {epoch}/{args.epochs}'):
            # read data from data loader
            inputs, target, stroke_num = data
            inputs, target = inputs.cuda(), target.cuda()

            # predicted fixed 6 axis data
            pred = model(inputs)

            # inverse transform
            pred = inverse_scaler_transform(pred, inputs)

            # MSE loss
            mse_loss = criterion(pred, target)

            # content loss: compare extractor features of prediction vs target
            gen_features = feature_extractor(pred)
            real_features = feature_extractor(target)
            content_loss = criterion(gen_features, real_features)

            # for compatible but bad for memory usage
            loss = mse_loss + content_loss

            err += loss.sum().item() * inputs.size(0)

            # out2csv every check interval epochs (default: 5)
            # NOTE(review): this fires for *every batch* of an interval epoch,
            # rewriting the same files each time — presumably only the last
            # batch's snapshot is wanted; confirm before relying on output.
            if epoch % args.check_interval == 0:
                out2csv(inputs, f'{epoch}_input', args.stroke_length)
                out2csv(pred, f'{epoch}_output', args.stroke_length)
                out2csv(target, f'{epoch}_target', args.stroke_length)

            if epoch  == args.epochs:
                if not os.path.exists('final_output'):
                    os.mkdir('final_output')
                save_final_predict_and_new_dataset(pred, stroke_num, f'final_output/', args, store_data_cnt)
                store_data_cnt+=args.batch_size

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # cross validation (no gradient tracking)
        model.eval()
        with torch.no_grad():
            for data in tqdm(valid_loader, desc=f'valid epoch: {epoch}/{args.epochs}'):

                inputs, target = data
                inputs, target = inputs.cuda(), target.cuda()

                pred = model(inputs)

                # inverse transform
                pred = inverse_scaler_transform(pred, inputs)

                # MSE loss
                mse_loss = criterion(pred, target)

                # content loss
                gen_features = feature_extractor(pred)
                real_features = feature_extractor(target)
                content_loss = criterion(gen_features, real_features)

                # for compatible
                loss = mse_loss + content_loss

                valid_err += loss.sum().item() * inputs.size(0)

        # NOTE(review): here `pred` is the last *validation* batch while
        # `stroke_num` is left over from the last *training* batch — this
        # pairing looks unintended; verify before trusting this output.
        if epoch  == args.epochs:
            save_final_predict_and_new_dataset(pred, stroke_num, f'final_output/', args, store_data_cnt)
            store_data_cnt+=args.batch_size

        # compute loss (mean over dataset sizes)
        err /= len(train_loader.dataset)
        valid_err /= len(valid_loader.dataset)
        print(f'train loss: {err:.4f}, valid loss: {valid_err:.4f}')

        # update every epoch
        # save model as pickle file
        # NOTE: the "best" model is selected by training error, not valid_err
        if best_err is None or err < best_err:
            best_err = err

            # save current epoch and model parameters
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                }
                , model_path)

        # update loggers
        writer.add_scalars('Loss/', {'train loss': err,
                                          'valid loss': valid_err}, epoch)

    writer.close()
Exemplo n.º 6
0
def train(model, train_loader, valid_loader, optimizer, criterion, args):
    """Train the model with weighted residual-MSE and content losses.

    Resumes from `<log_path>/<model_name>_<scale>x.pt` when args.load is set,
    supports optional LR scheduling and early stopping, and logs both
    per-iteration and per-epoch losses through the writer.

    Args:
        model: network mapping input stroke tensors to corrected ones.
        train_loader / valid_loader: DataLoaders yielding (inputs, target, _).
        optimizer: torch optimizer over model parameters.
        criterion: loss callable used for both MSE and content terms.
        args: namespace providing log_path, model_name, load, gpu_id,
            early_stop, patience, threshold, verbose, scheduler, step,
            factor, lr, alpha, beta, epochs, check_interval, out_num,
            save_path, stroke_length.
    """
    # content_loss
    best_err = None
    feature_extractor = FeatureExtractor().cuda()
    feature_extractor.eval()

    writer, log_path = writer_builder(args.log_path,
                                      args.model_name,
                                      load=args.load)

    # init data: epoch plus global iteration counters for the writer
    checkpoint = {
        'epoch': 1,  # start from 1
        'train_iter': 0,  # train iteration
        'valid_iter': 0,  # valid iteration
    }
    model_path = os.path.join(log_path, f'{args.model_name}_{args.scale}x.pt')

    # config
    # NOTE(review): `train_args` is not a parameter of this function — it is
    # presumably a module-level global mirroring `args`; confirm.
    model_config(train_args,
                 save=log_path)  # save model configuration before training

    # load model from exist .pt file
    if args.load and os.path.isfile(model_path):
        r"""
        load a pickle file from exist parameter

        state_dict: model's state dict
        epoch: parameters were updated in which epoch
        """
        checkpoint = torch.load(model_path, map_location=f'cuda:{args.gpu_id}')
        checkpoint['epoch'] += 1  # start from next epoch
        checkpoint['train_iter'] += 1
        checkpoint['valid_iter'] += 1
        model.load_state_dict(checkpoint['state_dict'])

    # initialize the early_stopping object
    if args.early_stop:
        early_stopping = EarlyStopping(patience=args.patience,
                                       threshold=args.threshold,
                                       verbose=args.verbose,
                                       path=model_path)

    if args.scheduler:
        scheduler = schedule_builder(optimizer, args.scheduler, args.step,
                                     args.factor)

    # progress bar postfix value
    pbar_postfix = {
        'MSE loss': 0.0,
        'Content loss': 0.0,
        'lr': args.lr,
    }

    for epoch in range(checkpoint['epoch'], args.epochs + 1):
        model.train()
        err = 0.0
        valid_err = 0.0

        train_bar = tqdm(train_loader,
                         desc=f'Train epoch: {epoch}/{args.epochs}')
        for data in train_bar:
            # load data from data loader
            inputs, target, _ = data
            inputs, target = inputs.cuda(), target.cuda()

            # predicted fixed 6 axis data
            pred = model(inputs)

            # MSE loss on residuals (prediction delta vs target delta)
            mse_loss = args.alpha * criterion(pred - inputs, target - inputs)

            # content loss in feature space, weighted by beta
            gen_features = feature_extractor(pred)
            real_features = feature_extractor(target)
            content_loss = args.beta * criterion(gen_features, real_features)

            # for compatible but bad for memory usage
            loss = mse_loss + content_loss

            # update progress bar
            pbar_postfix['MSE loss'] = mse_loss.item()
            pbar_postfix['Content loss'] = content_loss.item()

            # show current lr
            if args.scheduler:
                pbar_postfix['lr'] = optimizer.param_groups[0]['lr']

            train_bar.set_postfix(pbar_postfix)

            err += loss.sum().item() * inputs.size(0)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # update writer
            writer.add_scalar('Iteration/train loss',
                              loss.sum().item(), checkpoint['train_iter'])
            checkpoint['train_iter'] += 1

        # cross validation
        valid_bar = tqdm(valid_loader,
                         desc=f'Valid epoch:{epoch}/{args.epochs}',
                         leave=False)
        model.eval()
        # placeholders; on interval epochs they end up holding the *last*
        # validation batch, which is what gets exported to CSV below
        input_epoch = pred_epoch = target_epoch = torch.empty(0, 0)
        with torch.no_grad():
            for data in valid_bar:
                # for data in valid_loader:
                inputs, target, _ = data
                inputs, target = inputs.cuda(), target.cuda()

                pred = model(inputs)

                # MSE loss (unweighted here, unlike the training loop)
                mse_loss = criterion(pred - inputs, target - inputs)

                # content loss
                gen_features = feature_extractor(pred)
                real_features = feature_extractor(target)
                content_loss = criterion(gen_features, real_features)

                # for compatible
                loss = mse_loss + content_loss

                # update progress bar
                pbar_postfix['MSE loss'] = mse_loss.item()
                pbar_postfix['Content loss'] = content_loss.item()

                # show current lr
                if args.scheduler:
                    pbar_postfix['lr'] = optimizer.param_groups[0]['lr']

                valid_bar.set_postfix(pbar_postfix)

                valid_err += loss.sum().item() * inputs.size(0)

                # update writer
                writer.add_scalar('Iteration/valid loss',
                                  loss.sum().item(), checkpoint['valid_iter'])
                checkpoint['valid_iter'] += 1

                # out2csv every check interval epochs (default: 5)
                if epoch % args.check_interval == 0:
                    input_epoch = inputs
                    pred_epoch = pred
                    target_epoch = target

        # out2csv every check interval epochs (default: 5)
        if epoch % args.check_interval == 0:

            # tensor to csv file
            out2csv(input_epoch, f'{epoch}', 'input', args.out_num,
                    args.save_path, args.stroke_length)
            out2csv(pred_epoch, f'{epoch}', 'output', args.out_num,
                    args.save_path, args.stroke_length)
            out2csv(target_epoch, f'{epoch}', 'target', args.out_num,
                    args.save_path, args.stroke_length)

        # compute loss (mean over dataset sizes)
        err /= len(train_loader.dataset)
        valid_err /= len(valid_loader.dataset)
        print(f'\ntrain loss: {err:.4f}, valid loss: {valid_err:.4f}')

        # update scheduler
        if args.scheduler:
            scheduler.step()

        # update loggers
        writer.add_scalars(
            'Epoch',
            {
                'train loss': err,
                'valid loss': valid_err
            },
            epoch,
        )

        # early_stopping needs the validation loss to check if it has decresed,
        # and if it has, it will make a checkpoint of the current model
        if args.early_stop:
            early_stopping(valid_err, model, epoch)

            if early_stopping.early_stop:
                print("Early stopping")
                break
        # if early stop is false, store model when the err is lowest
        # (the first epoch always saves, which also initializes best_err;
        # note the selection uses *training* error, not valid_err)
        elif epoch == checkpoint['epoch'] or err < best_err:
            best_err = err  # save err in first epoch

            # save current epoch and model parameters
            torch.save(
                {
                    'state_dict': model.state_dict(),
                    'epoch': epoch,
                    'train_iter': checkpoint['train_iter'],
                    'valid_iter': checkpoint['valid_iter'],
                }, model_path)

    writer.close()
def main():
    """Train (TRAIN=True) or visualize (TRAIN=False) the SpatialAttentionModel
    on the gaze dataset, checkpointing the best-accuracy model."""
    # define parameters
    TRAIN = True
    num_class = 6
    batch_size = 1
    # time_step = 32
    epochs = 50
    cnn_feat_size = 256     # AlexNet
    gaze_size = 3
    gaze_lstm_hidden_size = 64
    gaze_lstm_projected_size = 128
    learning_rate = 0.0001
    momentum = 0.9
    weight_decay = 1e-4
    eval_freq = 1       # epoch
    print_freq = 1      # iteration
    # dataset_path = '../data/gaze_dataset'
    dataset_path = '../../gaze-net/gaze_dataset'
    img_size = (224, 224)
    log_path = '../log'
    logger = Logger(log_path, 'spatial')

    # define model: frozen AlexNet backbone feeding the attention model
    arch = 'alexnet'
    extractor_model = FeatureExtractor(arch=arch)
    extractor_model.features = torch.nn.DataParallel(extractor_model.features)
    extractor_model.cuda()      # remove this call when running on CPU
    extractor_model.eval()

    model = SpatialAttentionModel(num_class, cnn_feat_size,
                        gaze_size, gaze_lstm_hidden_size, gaze_lstm_projected_size)
    model.cuda()

    # define loss and optimizer
    # criterion = nn.CrossEntropyLoss()
    criterion = nn.CrossEntropyLoss().cuda()
    optimizer = torch.optim.SGD(model.parameters(), learning_rate,
                                momentum = momentum, weight_decay=weight_decay)

    # define generator (one generator provides both splits)
    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(dataset_path, subset='training', crop=False,
                    batch_size=batch_size, target_size= img_size, class_mode='sequence_pytorch')
    # small dataset, error using validation split
    val_data = trainGenerator.flow_from_directory(dataset_path, subset='validation', crop=False,
                batch_size=batch_size, target_size= img_size, class_mode='sequence_pytorch')
    # val_data = train_data

    # debugging helper: writes one contrast-adjusted frame to disk (currently unused)
    def test(train_data):
        [img_seq, gaze_seq], target = next(train_data)
        img = img_seq[100,:,:,:]
        img_gamma = adjust_contrast(img)
        imsave('contrast.jpg', img_gamma)
        imsave('original.jpg', img)

    # test(train_data)
    print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
    # img_seq: (ts,224,224,3), gaze_seq: (ts, 3), ouput: (ts, 6)
    # [img_seq, gaze_seq], output = next(train_data)
    # print("gaze data shape")
    # print(img_seq.shape)
    # print(gaze_seq.shape)
    # print(output.shape)

    # start Training
    para = {'bs': batch_size, 'img_size': img_size, 'num_class': num_class,
            'print_freq': print_freq}
    if TRAIN:
        print("get into training mode")
        best_acc = 0

        for epoch in range(epochs):
            adjust_learning_rate(optimizer, epoch, learning_rate)
            print('Epoch: {}'.format(epoch))
            # train for one epoch
            train(train_data, extractor_model, model, criterion, optimizer, epoch, logger, para)

            # evaluate on validation set; checkpoint whenever accuracy improves
            if epoch % eval_freq == 0 or epoch == epochs - 1:
                acc = validate(val_data, extractor_model, model, criterion, epoch, logger, para, False)
                is_best = acc > best_acc
                best_acc = max(acc, best_acc)
                save_checkpoint({
                    'epoch': epoch + 1,
                    'arch': arch,
                    'state_dict': model.state_dict(),
                    'best_acc': best_acc,
                    'optimizer': optimizer.state_dict(),
                }, is_best)
    else:
        # evaluation-only mode: load best checkpoint, dump visualizations
        model = load_checkpoint(model)
        print("get into testing and visualization mode")
        print("visualization for training data")
        vis_data_path = '../vis/train/'
        if not os.path.exists(vis_data_path):
            os.makedirs(vis_data_path)
        acc = validate(train_data, extractor_model, model, criterion, -1, \
                        logger, para, False, vis_data_path)
        print("visualization for validation data")
        vis_data_path = '../vis/val/'
        if not os.path.exists(vis_data_path):
            os.makedirs(vis_data_path)
        acc = validate(val_data, extractor_model, model, criterion, -1, \
                        logger, para, True, vis_data_path)
def main():
    """Train (TRAIN=True) or visualize (TRAIN=False) the MultipleAttentionModel
    on the gaze dataset, checkpointing the best-accuracy model."""
    # define parameters
    TRAIN = True
    time_skip = 2
    num_class = 6
    batch_size = 1
    epochs = 50
    cnn_feat_size = 256  # AlexNet
    gaze_size = 3
    gaze_lstm_hidden_size = 64
    gaze_lstm_projected_size = 128
    temporal_projected_size = 128
    queue_size = 32
    learning_rate = 0.0001
    momentum = 0.9
    weight_decay = 1e-4
    eval_freq = 1  # epoch
    print_freq = 1  # iteration
    dataset_path = '../../gaze-net/gaze_dataset'
    # dataset_path = '../../gaze-net/gaze_dataset'
    img_size = (224, 224)
    extractor = True  # fine-tune the last two layers of feat_extractor or not
    log_path = '../log'
    logger = Logger(log_path, 'multiple')

    # Backbone architecture name, used both to build the optional external
    # extractor and to tag checkpoints.
    # Bug fix: `arch` was previously assigned only inside the
    # `extractor == False` branch, so with extractor=True the checkpoint
    # dict below raised NameError on the first save.
    arch = 'alexnet'

    # define model
    if extractor == False:
        # external frozen feature extractor
        extractor_model = FeatureExtractor(arch=arch)
        extractor_model.features = torch.nn.DataParallel(
            extractor_model.features)
        # extractor_model.cuda()      # uncomment this line when using GPU
        extractor_model.eval()
    else:
        # feature extraction happens inside MultipleAttentionModel itself
        extractor_model = None

    model = MultipleAttentionModel(num_class,
                                   cnn_feat_size,
                                   gaze_size,
                                   gaze_lstm_hidden_size,
                                   gaze_lstm_projected_size,
                                   temporal_projected_size,
                                   queue_size,
                                   extractor=extractor)
    model.cuda()

    # define loss and optimizer; only trainable (requires_grad) params are optimized
    criterion = nn.CrossEntropyLoss().cuda()
    param_list = []
    for i, param in enumerate(model.parameters()):
        if param.requires_grad == True:
            print(param.size())
            param_list.append(param)
    optimizer = torch.optim.SGD(param_list,
                                learning_rate,
                                momentum=momentum,
                                weight_decay=weight_decay)

    # define generator (one generator provides both splits)
    trainGenerator = gaze_gen.GazeDataGenerator(validation_split=0.2)
    train_data = trainGenerator.flow_from_directory(
        dataset_path,
        subset='training',
        crop=False,
        batch_size=batch_size,
        target_size=img_size,
        class_mode='sequence_pytorch',
        time_skip=time_skip)
    # small dataset, error using validation split
    val_data = trainGenerator.flow_from_directory(
        dataset_path,
        subset='validation',
        crop=False,
        batch_size=batch_size,
        target_size=img_size,
        class_mode='sequence_pytorch',
        time_skip=time_skip)
    # val_data = train_data

    # start Training
    para = {
        'bs': batch_size,
        'img_size': img_size,
        'num_class': num_class,
        'print_freq': print_freq
    }
    if TRAIN:
        print("get into training mode")
        best_acc = 0
        # acc = validate(val_data, extractor_model, model, criterion, 0, logger, para, False)
        for epoch in range(epochs):
            adjust_learning_rate(optimizer, epoch, learning_rate)
            print('Epoch: {}'.format(epoch))
            # train for one epoch
            train(train_data, extractor_model, model, criterion, optimizer,
                  epoch, logger, para)

            # evaluate on validation set; checkpoint whenever accuracy improves
            if epoch % eval_freq == 0 or epoch == epochs - 1:
                acc = validate(val_data, extractor_model, model, criterion,
                               epoch, logger, para, False)
                is_best = acc > best_acc
                best_acc = max(acc, best_acc)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'arch': arch,
                        'state_dict': model.state_dict(),
                        'best_acc': best_acc,
                        'optimizer': optimizer.state_dict(),
                    }, is_best)
    else:
        # evaluation-only mode: load best checkpoint, dump visualizations
        model = load_checkpoint(model)
        print("get into testing and visualization mode")
        print("visualization for training data")
        vis_data_path = '../vis/train/'
        if not os.path.exists(vis_data_path):
            os.makedirs(vis_data_path)
        acc = validate(train_data, extractor_model, model, criterion, -1, \
                        logger, para, True, vis_data_path)
        print("visualization for validation data")
        vis_data_path = '../vis/val/'
        if not os.path.exists(vis_data_path):
            os.makedirs(vis_data_path)
        acc = validate(val_data, extractor_model, model, criterion, -1, \
                        logger, para, True, vis_data_path)
Exemplo n.º 9
0
def test(model, test_loader, criterion, args):
    """Evaluate the model on the test set, export input/output/target CSV
    snapshots, and print the combined (alpha*MSE + beta*content) test error.

    Requires args.load so the checkpoint directory can be resolved.
    """
    number = args.test_path.split('test_')[1]
    save_path = './output/char0' + number
    print(f'Saving data in {save_path}')

    # set model path
    # Bug fix: `model_path` was only defined inside this branch, so
    # args.load == False crashed with NameError at torch.load below;
    # fail fast with a clear message instead.
    if args.load is not False:
        _, log_path = writer_builder(args.log_path,
                                     args.model_name,
                                     load=args.load)
        model_path = os.path.join(log_path,
                                  f'{args.model_name}_{args.scale}x.pt')
    else:
        raise ValueError(
            'test() requires args.load to locate the trained model checkpoint')

    # load model parameters
    checkpoint = torch.load(model_path, map_location=f'cuda:{args.gpu_id}')

    # try-except to stay compatible with checkpoints from the older layout
    try:
        model.load_state_dict(checkpoint['state_dict'])
    except Exception:  # narrowed from bare `except:` (which also caught KeyboardInterrupt)
        print('Warning: load older version')
        # older checkpoints fold the bottleneck into the feature stack
        model.feature = nn.Sequential(*model.feature, *model.bottle)
        model.bottle = nn.Sequential()
        model.load_state_dict(checkpoint['state_dict'], strict=False)

    model.eval()

    # declare content loss
    feature_extractor = FeatureExtractor().cuda()
    feature_extractor.eval()

    err = 0.0

    # out2csv bookkeeping
    i = 0  # count the number of loops (batches)
    j = 0  # count the number of data rows exported

    for data in tqdm(test_loader, desc=f'scale: {args.scale}'):
        inputs, target, _ = data
        inputs, target = inputs.cuda(), target.cuda()

        pred = model(inputs)

        # export every args.test_num-th row of this batch to CSV
        while j - (i * args.batch_size) < pred.size(0):
            out2csv(inputs,
                    f'test_{int(j/args.test_num)+1}',
                    'input',
                    j - (i * args.batch_size),
                    save_path,
                    args.stroke_length,
                    spec_flag=True)
            out2csv(pred,
                    f'test_{int(j/args.test_num)+1}',
                    'output',
                    j - (i * args.batch_size),
                    save_path,
                    args.stroke_length,
                    spec_flag=True)
            out2csv(target,
                    f'test_{int(j/args.test_num)+1}',
                    'target',
                    j - (i * args.batch_size),
                    save_path,
                    args.stroke_length,
                    spec_flag=True)
            j += args.test_num
        i += 1

        # MSE loss on residuals, weighted by alpha
        mse_loss = args.alpha * criterion(pred - inputs, target - inputs)

        # content loss in feature space, weighted by beta
        gen_feature = feature_extractor(pred)
        real_feature = feature_extractor(target)
        content_loss = args.beta * criterion(gen_feature, real_feature)

        # for compatible
        loss = content_loss + mse_loss
        err += loss.sum().item() * inputs.size(0)

    err /= len(test_loader.dataset)
    print(f'test error:{err:.4f}')
Exemplo n.º 10
0
def main(num_epochs=10, embedding_dim=256, data_dir="data/"):
    """ Function to train the model.

    Args:
        num_epochs: int
            Number of full dataset iterations to train the model.
        embedding_dim: int
            Output of the CNN model and input of the LSTM embedding size.
        data_dir: str
            Path to the folder of the data.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"WORKING WITH: {device}")

    # Define the paths for train and validation
    train_json_path = data_dir + "annotations/captions_train2014.json"
    train_root_dir = data_dir + "train2014"
    valid_json_path = data_dir + "annotations/captions_val2014.json"
    valid_root_dir = data_dir + "val2014"

    transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    train_dataset = CocoDataset(json_path=train_json_path,
                                root_dir=train_root_dir,
                                transform=transform)

    train_coco_dataset = get_data_loader(train_dataset, batch_size=128)

    valid_dataset = CocoDataset(json_path=valid_json_path,
                                root_dir=valid_root_dir,
                                transform=transform)

    valid_coco_dataset = get_data_loader(valid_dataset, batch_size=1)

    encoder = FeatureExtractor(embedding_dim).to(device)
    decoder = CaptionGenerator(embedding_dim, 512,
                               len(train_dataset.vocabulary), 1).to(device)

    criterion = nn.CrossEntropyLoss()
    # only the decoder and the encoder's projection head (linear + bn) are
    # trained; the CNN backbone stays frozen
    params = list(decoder.parameters()) + list(
        encoder.linear.parameters()) + list(encoder.bn.parameters())
    optimizer = optim.Adam(params, lr=0.01)

    print(f"TRAIN DATASET: {len(train_coco_dataset)}")
    print(f"VALID DATASET: {len(valid_coco_dataset)}")

    for epoch in range(num_epochs):
        encoder.train()
        decoder.train()
        train_loss = 0.0
        valid_loss = 0.0
        for i, (images, captions,
                descriptions) in enumerate(train_coco_dataset):

            images = images.to(device)
            captions = captions.to(device)

            features = encoder(images)
            outputs = decoder(features, captions)

            # token-level cross-entropy over the flattened caption sequence
            loss = criterion(outputs.view(-1, len(train_dataset.vocabulary)),
                             captions.view(-1))

            encoder.zero_grad()
            decoder.zero_grad()

            loss.backward()
            optimizer.step()

            train_loss += loss.item()

            # Save the model checkpoints
            if (i + 1) % 1000 == 0:
                torch.save(
                    decoder.state_dict(),
                    os.path.join("models",
                                 'decoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
                torch.save(
                    encoder.state_dict(),
                    os.path.join("models",
                                 'encoder-{}-{}.ckpt'.format(epoch + 1,
                                                             i + 1)))
        encoder.eval()
        decoder.eval()
        bleu = 0.0
        # Bug fix: validation previously ran with autograd enabled, building
        # (and keeping alive) a graph for every batch; no_grad() removes that
        # memory/compute waste without changing any reported metric.
        with torch.no_grad():
            for i, (images, captions,
                    descriptions) in enumerate(valid_coco_dataset):
                if (i > 80000):
                    break
                images = images.to(device)
                captions = captions.to(device)
                features = encoder(images)
                outputs = decoder(features, captions)
                loss = criterion(outputs.view(-1, len(train_dataset.vocabulary)),
                                 captions.view(-1))
                valid_loss += loss.item()
                bleu += calculate_bleu(decoder, features, descriptions,
                                       train_coco_dataset)
        # NOTE(review): the 80000 divisor is hard-coded and does not match the
        # actual number of validation batches when the set is smaller; the
        # reported averages are therefore only comparable across epochs.
        print(
            "Epoch: {}, Train Loss: {:.4f}, Valid Loss: {:.4f}, BLEU: {:.4f}".
            format(epoch, train_loss / len(train_coco_dataset),
                   valid_loss / 80000, bleu / 80000))