def test_01_model_train(self):
    model_file = DIRECTORY_MODELS + 'arima.pickle'
    date = '2018-11-20'
    duration = 30
    country = None
    model(date, duration, country)
    self.assertTrue(os.path.exists(model_file))
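For context, here is a minimal sketch of a model() training entry point that would satisfy this test (and the companion prediction test below). The DIRECTORY_MODELS value, the return shape, and the skipped ARIMA fitting are assumptions for illustration, not the project's actual implementation.

import os
import pickle

DIRECTORY_MODELS = 'models/'  # assumed location; the test expects 'arima.pickle' here

def model(date, duration, country=None):
    # Hypothetical trainer: fit a forecaster for `duration` days starting at `date`
    # (optionally restricted to `country`) and persist it as a pickle.
    fitted = {'date': date, 'duration': duration, 'country': country}  # stand-in for a fitted ARIMA object
    os.makedirs(DIRECTORY_MODELS, exist_ok=True)
    with open(os.path.join(DIRECTORY_MODELS, 'arima.pickle'), 'wb') as fh:
        pickle.dump(fitted, fh)
    return {'arima': fitted}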
Example #2
def train(model, train_dataloader, val_dataloader, optimizer,
          criteriaOfEachLevel, n_ary, decode_level, eval_point, save_path):

    total_batch = len(train_dataloader)
    eval_point = int(eval_point * total_batch)
    total_acc = 0
    for i, batch in enumerate(train_dataloader):

        batch_size = len(batch)
        optimizer.zero_grad()
        loss = 0
        acc = 0

        for j, sample in enumerate(batch):
            # source to reference
            output = model(sample["source"]["syntax"],
                           sample["reference"]["syntax"])
            output = output.view(-1, output.size(2))
            target = sample["reference"]["label"].to(cuda)  # cuda is assumed to be a torch.device defined at module level
            loss += computeLossByLevel(output, target,
                                       sample["reference"]["level"],
                                       criteriaOfEachLevel)
            acc += ((output.argmax(1) == target).view(-1, n_ary).all(1)).sum()

            # reference to source
            output = model(sample["reference"]["syntax"],
                           sample["source"]["syntax"])
            output = output.view(-1, output.size(2))
            target = sample["source"]["label"].to(cuda)
            loss += computeLossByLevel(output, target,
                                       sample["source"]["level"],
                                       criteriaOfEachLevel)
            acc += ((output.argmax(1) == target).view(-1, n_ary).all(1)).sum()

        total_acc += acc
        acc = acc / (2 * batch_size * decode_level)
        loss = loss / (2 * batch_size * decode_level)
        print("Processing {:05d}/{} batch. loss: {:.5f} accuracy: {:.4f}\r".
              format(i, total_batch, loss, acc),
              end='')

        loss.backward()
        optimizer.step()

        if i % eval_point == 0:
            print("\nStart evaluation...\n")
            eval_acc = evaluate(model, val_dataloader, n_ary, decode_level,
                                criteriaOfEachLevel)
            # NOTE: `iteration` is not defined in this function; it presumably
            # refers to an epoch counter from the enclosing scope.
            path = os.path.join(
                save_path,
                "model_{}_{}_{:.3f}.pt".format(iteration, i / eval_point,
                                               eval_acc))
            torch.save(model.state_dict(), path)

    print("Training Accuracy: {}".format(total_acc /
                                         len(train_dataloader.dataset)))
def test_02_model_predict(self):
    key = 'arima'
    date = '2018-11-20'
    duration = 30
    country = None
    result = model(date, duration, country)
    self.assertIn(key, result)
Example #4
    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }
Example #5
def __init__(self,
             batch_size=500,
             lr=0.001,
             training_times=500
             ):
    """
    Auto-encoder implementation.
    :param batch_size: number of images per batch
    :param lr: learning rate
    :param training_times: number of training epochs
    """
    self.batch_size = batch_size
    self.lr = lr
    self.training_times = training_times
    self.device = t.device('cuda' if t.cuda.is_available() else 'cpu')
    self.train_data, self.test_data = data_create(self.batch_size)
    self.model = model().to(self.device)
    self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
    self.loss = MSELoss()
Example #6
def predict(sentence, bpe_model, model):
    model.eval()

    if isinstance(sentence, str):
        sentence = preprocess(sentence)
        tokens = bpe_tokenizer(sentence)
    else:
        tokens = [int(token) for token in sentence]

    src_indexes = tokens

    # convert to tensor format
    # since the inference is done on single sentence, batch size is 1
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(1).to(device)
    # src_tensor => [seq_len, 1]

    src_length = torch.LongTensor([len(src_indexes)])
    # src_length => [1]

    with torch.no_grad():
        predictions = model(src_tensor, src_length)

    return predictions
Example #7
def predict():
    # Check date parameter in request
    if 'date' in request.args:
        date = request.args['date']
    else:
        return "Error: No date parameter was provided."
    # Check country parameter in request
    if 'country' in request.args:
        country = request.args['country']
    else:
        country = None
    # Check duration parameter in request
    if 'duration' in request.args:
        duration = request.args['duration']
        if duration == '':
            duration = 30
        else:
            duration = int(duration)
    else:
        duration = 30
    # Call model with parameters
    result = model(date, duration, country)
    # Return result
    return jsonify({'data': result})
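A hedged sketch of how this endpoint might be queried from a client. The route path '/predict' and port 5000 are assumptions, since the snippet does not show the route registration.

import requests

# Illustrative only: assumes the handler above is registered at '/predict'
# on a Flask app serving on localhost:5000.
resp = requests.get('http://localhost:5000/predict',
                    params={'date': '2018-11-20', 'duration': 30, 'country': 'US'})
print(resp.json())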
Example #8
import matplotlib.pyplot as plt
import sys
import json
import time
import os

from OCC.Core.gp import gp_Pnt, gp_Vec, gp_Dir
from OCC.Core.gp import gp_Ax1, gp_Ax2, gp_Ax3
from OCC.Core.gp import gp_XYZ
from OCC.Core.gp import gp_Lin

from src.base import create_tempnum
from src.model import model, model_base
from src.jsonfile import write_json

obj = model(cfgfile="./cfg/model.json")
print(obj.cfg["name"])

md0 = obj.set_model(name="surf0")
md1 = obj.set_model(name="surf1")
md2 = obj.set_model(name="surf")
md2.rim = obj.make_PolyWire(skin=None)
md3 = model_base(meta={"name": "surf3"})
md3.axs.Translate(gp_Pnt(), gp_Pnt(0, 0, 100))

meta = {}
meta["surf1"] = md1.export_dict()
meta["surf2"] = md2.export_dict()
meta["surf3"] = md3.export_dict()
write_json(create_tempnum("model", obj.tmpdir, ext=".json"), meta)
Example #9
    emoji_to_pic = {
        'happy': None,
        'disgust': None,
        'sad': None,
        'surprise': None,
        'fear': None,
        'angry': None
    }

    # Build the glob in an OS-independent way instead of hard-coding '\\'
    files = glob.glob(os.path.join(EMOJI_FILE_PATH, '*.png'))

    logger.info('loading the emoji png files in memory ...')
    for file in tqdm.tqdm(files):
        logger.debug('file path: {}'.format(file))
        # derive the emotion name from the file name, independent of the OS separator
        emotion = os.path.splitext(os.path.basename(file))[0]
        emoji_to_pic[emotion] = cv.imread(file, -1)

    X = tf.placeholder(tf.float32, shape=[None, 48, 48, 1])

    keep_prob = tf.placeholder(tf.float32)

    y_conv = model(X, keep_prob)

    saver = tf.train.Saver()

    with tf.Session(config=config) as sess:
        saver.restore(sess, os.path.join(CHECKPOINT_SAVE_PATH, 'model.ckpt'))
        logger.info('Opening the camera for getting the video feed ...')
        from_cam(sess)
Example #10
File: main.py Project: alpop/DALI
def train(args):
    if args.amp:
        amp_handle = amp.init(enabled=args.fp16)

    args.distributed = False
    if 'WORLD_SIZE' in os.environ:
        args.distributed = int(os.environ['WORLD_SIZE']) > 1

    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        args.N_gpu = torch.distributed.get_world_size()
    else:
        args.N_gpu = 1

    dboxes = dboxes300_coco()
    encoder = Encoder(dboxes)
    cocoGt = get_coco_ground_truth(args)

    ssd300 = model(args)
    args.learning_rate = args.learning_rate * args.N_gpu * (args.batch_size / 32)
    iteration = 0
    loss_func = Loss(dboxes)

    loss_func.cuda()

    optimizer = torch.optim.SGD(
        tencent_trick(ssd300), 
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)

    scheduler = MultiStepLR(
        optimizer=optimizer, 
        milestones=args.multistep, 
        gamma=0.1)

    if args.fp16:
        if args.amp:
            optimizer = amp_handle.wrap_optimizer(optimizer)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.)

    val_dataloader, inv_map = get_val_dataloader(args)
    train_loader = get_train_loader(args, dboxes)

    acc = 0
    logger = Logger(args.batch_size, args.local_rank)
    
    for epoch in range(0, args.epochs):
        logger.start_epoch()
        scheduler.step()

        iteration = train_loop(
            ssd300, loss_func, epoch, optimizer, 
            train_loader, iteration, logger, args)

        logger.end_epoch()

        if epoch in args.evaluation:
            acc = evaluate(ssd300, val_dataloader, cocoGt, encoder, inv_map, args)
            if args.local_rank == 0:
                print('Epoch {:2d}, Accuracy: {:4f} mAP'.format(epoch, acc))

        if args.data_pipeline == 'dali':
            train_loader.reset()

    return acc, logger.average_speed()
Example #11
def main():
    image_root = os.path.join(cfg.DATA_DIR, cfg.DATASET, 'image')
    datasetname = cfg.DATASET
    dataset = Dataset(datasetname)
    query_lists, _ = dataset.get_retrieval_list()
    feature_file = open(os.path.join(cfg.FEATURE_DIR, cfg.DATASET, \
                                     cfg.MODEL, cfg.DISTANCE_METRIC, \
                                     cfg.SAVE_NAME), 'rb')
    feature_dict = cPickle.load(feature_file)

    save_root = os.path.join(cfg.RESULT_DIR, cfg.DATASET, cfg.MODEL,
                             cfg.DISTANCE_METRIC)
    if not os.path.exists(save_root):
        os.makedirs(save_root)

    net = model(cfg.MODEL, cfg.CAFFEMODEL, cfg.PROTOTXT, cfg.FEATURE_NAME[0],
                cfg.GPU_ID)

    for idx, query_list in enumerate(query_lists):
        query_name, query_id = query_list.strip().split(' ')

        if query_name in feature_dict.keys():
            q_feature = copy.deepcopy(feature_dict[query_name]['feature'])
            q_id = copy.deepcopy(feature_dict[query_name]['id'])
        else:
            image_path = os.path.join(image_root, query_name + '.jpg')
            image = cv2.imread(image_path)
            q_feature = copy.deepcopy(net.forward(image, cfg.FEATURE_NAME[0]))
            q_id = dataset.get_image_id(query_name)

        retrieval_results = []

        for key in feature_dict.keys():
            p_feature = feature_dict[key]['feature']
            p_id = feature_dict[key]['id']

            if cfg.DISTANCE_METRIC == 'L2':
                tmp = norm_L2_distance(q_feature, p_feature)
            elif cfg.DISTANCE_METRIC == 'cosin':
                tmp = norm_cosin_distance(q_feature, p_feature)

            retrieval_results.append({
                'filename': key,
                'distance': tmp,
                'id': p_id
            })

        # Python 3: cmp-style sorts are gone; keep the original descending order
        retrieval_results.sort(key=lambda x: x['distance'], reverse=True)
        #pdb.set_trace()

        f = open(os.path.join(save_root, query_name + '.txt'), 'w')

        # the query image itself is also stored in the dataset
        for i in range(0, cfg.TOP_K + 1):
            filename = retrieval_results[i]['filename']
            distance = retrieval_results[i]['distance']
            f.writelines(
                str(query_name) + '_' + str(filename) + ' ' + str(distance) +
                '\n')
        f.close()

        view_bar(idx, len(query_lists))
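The helpers norm_L2_distance and norm_cosin_distance are not shown. Below is a small self-contained sketch of what the cosine variant presumably computes (a similarity score on L2-normalised features, which matches the descending sort above); it is an assumption, not the project's code.

import numpy as np

def norm_cosine_distance(a, b):
    # Hypothetical stand-in for `norm_cosin_distance`: cosine similarity of the
    # two feature vectors after L2 normalisation (higher = more similar).
    a = np.asarray(a, dtype=np.float64).ravel()
    b = np.asarray(b, dtype=np.float64).ravel()
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))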
Example #12
def main():

    # args = parser.parse_args()
    args = Opts()
    enc = encoder.get_encoder(args.model_name)
    hparams = model.default_hparams()

    with open(os.path.join('models', args.model_name, 'hparams.json')) as f:
        # hparams.override_from_dict(json.load(f))
        hparams.override_from_dict(json.loads(f.read()))

    if args.sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    if args.model_name == '345M':
        args.memory_saving_gradients = True
        if args.optimizer == 'adam':
            args.only_train_transformer_layers = True

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        if args.val_every > 0:
            val_context = tf.placeholder(tf.int32, [args.val_batch_size, None])
            val_output = model.model(hparams=hparams, X=val_context)
            val_loss = tf.reduce_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=val_context[:, 1:],
                    logits=val_output['logits'][:, :-1]))
            val_loss_summary = tf.summary.scalar('val_loss', val_loss)

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=args.sample_length,
                                           context=context,
                                           batch_size=args.batch_size,
                                           temperature=1.0,
                                           top_k=args.top_k,
                                           top_p=args.top_p)

        all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        train_vars = [v for v in all_vars if '/h' in v.name
                      ] if args.only_train_transformer_layers else all_vars

        if args.optimizer == 'adam':
            opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)
        elif args.optimizer == 'sgd':
            opt = tf.train.GradientDescentOptimizer(
                learning_rate=args.learning_rate)
        else:
            exit('Bad optimizer: ' + args.optimizer)

        if args.accumulate_gradients > 1:
            if args.memory_saving_gradients:
                exit(
                    "Memory saving gradients are not implemented for gradient accumulation yet."
                )
            opt = AccumulatingOptimizer(opt=opt, var_list=train_vars)
            opt_reset = opt.reset()
            opt_compute = opt.compute_gradients(loss)
            opt_apply = opt.apply_gradients()
            summary_loss = tf.summary.scalar('loss', opt_apply)
        else:
            if args.memory_saving_gradients:
                opt_grads = memory_saving_gradients.gradients(loss, train_vars)
            else:
                opt_grads = tf.gradients(loss, train_vars)
            opt_grads = list(zip(opt_grads, train_vars))
            opt_apply = opt.apply_gradients(opt_grads)
            summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(var_list=all_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', args.model_name))
        elif args.restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', args.model_name))
        else:
            ckpt = tf.train.latest_checkpoint(args.restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc,
                              args.dataset,
                              args.combine,
                              encoding=args.encoding)
        data_sampler = Sampler(chunks)
        if args.val_every > 0:
            if args.val_dataset:
                val_chunks = load_dataset(enc,
                                          args.val_dataset,
                                          args.combine,
                                          encoding=args.encoding)
            else:
                val_chunks = chunks
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        if args.val_every > 0:
            # Sample from validation set once with fixed seed to make
            # it deterministic during training as well as across runs.
            val_data_sampler = Sampler(val_chunks, seed=1)
            val_batches = [[
                val_data_sampler.sample(1024)
                for _ in range(args.val_batch_size)
            ] for _ in range(args.val_batch_count)]

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                       global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            print('Generating samples...')
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < args.sample_num:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: args.batch_size * [context_tokens]})
                for i in range(min(args.sample_num - index, args.batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
            print(text)
            maketree(os.path.join(SAMPLE_DIR, args.run_name))
            with open(os.path.join(SAMPLE_DIR, args.run_name,
                                   'samples-{}').format(counter),
                      'w',
                      encoding=args.encoding) as fp:
                fp.write('\n'.join(all_text))

        def validation():
            print('Calculating validation loss...')
            losses = []
            for batch in tqdm.tqdm(val_batches):
                losses.append(
                    sess.run(val_loss, feed_dict={val_context: batch}))
            v_val_loss = np.mean(losses)
            v_summary = sess.run(val_loss_summary,
                                 feed_dict={val_loss: v_val_loss})
            summary_log.add_summary(v_summary, counter)
            summary_log.flush()
            print('[{counter} | {time:2.2f}] validation loss = {loss:2.2f}'.
                  format(counter=counter,
                         time=time.time() - start_time,
                         loss=v_val_loss))

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                if counter % args.sample_every == 0:
                    generate_samples()
                if args.val_every > 0 and (counter % args.val_every == 0
                                           or counter == 1):
                    validation()

                if args.accumulate_gradients > 1:
                    sess.run(opt_reset)
                    for _ in range(args.accumulate_gradients):
                        sess.run(opt_compute,
                                 feed_dict={context: sample_batch()})
                    (v_loss, v_summary) = sess.run((opt_apply, summaries))
                else:
                    (_, v_loss, v_summary) = sess.run(
                        (opt_apply, loss, summaries),
                        feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(counter=counter,
                            time=time.time() - start_time,
                            loss=v_loss,
                            avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
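maketree is called throughout this and the following examples but never defined in the snippets; in the gpt-2 fine-tuning scripts it is usually just a guarded directory creation. A sketch under that assumption:

import os

def maketree(path):
    # Create `path` (and any missing parents), ignoring the error if it already exists.
    os.makedirs(path, exist_ok=True)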
Example #13
def train_main(dataset,
               model_name='117M',
               seed=None,
               batch_size=1,
               sample_length=1023,
               sample_num=50,
               sample_every=100,
               run_name='dnd_biographies08',
               restore_from='latest',
               mode="test",
               max_iterations=50000,
               loss_threshold=0.8,
               save_every=1000):

    enc = encoder.get_encoder(model_name)
    hparams = model.default_hparams()
    with open(os.path.join('models', model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    if sample_length is None:
        sample_length = hparams.n_ctx // 2
    elif sample_length > hparams.n_ctx:
        raise ValueError("Can't get samples longer than window size: %s" %
                         hparams.n_ctx)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        np.random.seed(seed)
        tf.set_random_seed(seed)
        output = model.model(hparams=hparams, X=context)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        tf_sample = sample.sample_sequence(hparams=hparams,
                                           length=sample_length,
                                           context=context,
                                           batch_size=batch_size,
                                           temperature=1.0,
                                           top_k=40)

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
        opt = tf.train.AdamOptimizer(1e-4).minimize(loss, var_list=train_vars)

        saver = tf.train.Saver(var_list=train_vars,
                               max_to_keep=5,
                               keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, run_name))
            if ckpt is None:
                # Get fresh GPT weights if new run.
                ckpt = tf.train.latest_checkpoint(
                    os.path.join('models', model_name))
        elif restore_from == 'fresh':
            ckpt = tf.train.latest_checkpoint(
                os.path.join('models', model_name))
        else:
            ckpt = tf.train.latest_checkpoint(restore_from)
        print('Loading checkpoint', ckpt)
        saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(enc, dataset)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        if os.path.exists(os.path.join(CHECKPOINT_DIR, run_name, 'counter')):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, run_name,
                             'model-{}').format(counter))
            saver.save(sess,
                       os.path.join(CHECKPOINT_DIR, run_name, 'model'),
                       global_step=counter)
            with open(os.path.join(CHECKPOINT_DIR, run_name, 'counter'),
                      'w') as fp:
                fp.write(str(counter) + '\n')

        def generate_samples():
            context_tokens = data_sampler.sample(1)
            all_text = []
            index = 0
            while index < sample_num or sample_num == 0:
                out = sess.run(
                    tf_sample,
                    feed_dict={context: batch_size * [context_tokens]})
                for i in range(min(sample_num - index, batch_size)):
                    text = enc.decode(out[i])
                    text = '======== SAMPLE {} ========\n{}\n'.format(
                        index + 1, text)
                    all_text.append(text)
                    index += 1
                    print(text)
            # print(''.join(all_text))
            maketree(os.path.join(SAMPLE_DIR, run_name))
            with open(
                    os.path.join(SAMPLE_DIR, run_name,
                                 'samples-{}').format(counter), 'w') as fp:
                fp.write('\n'.join(all_text))

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            if mode == "train":
                while counter <= max_iterations:
                    if counter % save_every == 0:
                        save()
                    if counter % sample_every == 0:
                        generate_samples()

                    batch = [
                        data_sampler.sample(1024) for _ in range(batch_size)
                    ]

                    _, lv = sess.run((opt, loss), feed_dict={context: batch})

                    avg_loss = (avg_loss[0] * 0.99 + lv,
                                avg_loss[1] * 0.99 + 1.0)

                    print(
                        '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                        .format(counter=counter,
                                time=time.time() - start_time,
                                loss=lv,
                                avg=avg_loss[0] / avg_loss[1]))

                    counter += 1
                    if counter > 100:
                        if (avg_loss[0] / avg_loss[1]) < loss_threshold:
                            counter = max_iterations + 1
            else:
                generate_samples()
        except KeyboardInterrupt:
            print('interrupted')
            save()
Example #14
    parser.add_argument('--fp16-mode',
                        type=str,
                        default='off',
                        choices=['off', 'static', 'amp'],
                        help='Half precision mode to use')
    opt = parser.parse_args()

    if opt.fp16_mode != 'off':
        opt.fp16 = True
        opt.amp = (opt.fp16_mode == 'amp')
    else:
        opt.fp16 = False
        opt.amp = False
    if opt.amp:
        amp_handle = amp.init(enabled=opt.fp16)
    model = model(opt)
    optimizer = torch.optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9)

    model = torch.nn.DataParallel(model, device_ids=[0, 1, 2, 3])

    if opt.fp16:
        print("INFO: Use Fp16")
        if opt.amp:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
            # optimizer = amp_handle.wrap_optimizer(optimizer)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=128.)
    # Prepare dataset
    print("INFO: Prepare Datasets")
    if opt.data_pipeline == 'dali':
        eii = ExternalInputIterator(opt.batch_size, opt.img_path,
Example #15
def main():
    args = parser.parse_args()

    if args.dataset == 'lpd_5':
        tracks = ['Drums', 'Piano', 'Guitar', 'Bass', 'Strings']
    elif args.dataset == 'lpd_17':
        tracks = ['Drums', 'Piano', 'Chromatic Percussion', 'Organ', 'Guitar', 'Bass', 'Strings', 'Ensemble', 'Brass', 'Reed', 'Pipe', 'Synth Lead', 'Synth Pad', 'Synth Effects', 'Ethnic', 'Percussive', 'Sound Effects']
    else:
        print('invalid dataset name.')
        exit()
    trc_len = len(tracks)
    note_size = 84
    time_note = note_size*trc_len + 1
    end_note = note_size*trc_len + 2
    hparams = HParams(**{
      "n_vocab": end_note+1,
      "n_ctx": 1024,
      "n_embd": 768,
      "n_head": 12,
      "n_layer": 12
    })

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.visible_device_list = args.gpu
    config.graph_options.rewrite_options.layout_optimizer = rewriter_config_pb2.RewriterConfig.OFF
    with tf.Session(config=config) as sess:
        context = tf.placeholder(tf.int32, [args.batch_size, None])
        context_in = randomize(context, hparams, args.noise)
        output = model.model(hparams=hparams, X=context_in)
        loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=context[:, 1:], logits=output['logits'][:, :-1]))

        train_vars = [v for v in tf.trainable_variables() if 'model' in v.name]

        opt = tf.train.AdamOptimizer(learning_rate=args.learning_rate)

        opt_grads = tf.gradients(loss, train_vars)
        opt_grads = list(zip(opt_grads, train_vars))
        opt_apply = opt.apply_gradients(opt_grads)
        summary_loss = tf.summary.scalar('loss', loss)

        summary_lr = tf.summary.scalar('learning_rate', args.learning_rate)
        summaries = tf.summary.merge([summary_lr, summary_loss])

        summary_log = tf.summary.FileWriter(
            os.path.join(CHECKPOINT_DIR, args.run_name))

        saver = tf.train.Saver(
            var_list=train_vars,
            max_to_keep=5,
            keep_checkpoint_every_n_hours=2)
        sess.run(tf.global_variables_initializer())

        if args.restore_from == 'latest':
            ckpt = tf.train.latest_checkpoint(
                os.path.join(CHECKPOINT_DIR, args.run_name))
            print('Loading checkpoint', ckpt)
            saver.restore(sess, ckpt)
        elif args.restore_from != 'fresh':
            ckpt = tf.train.latest_checkpoint(args.restore_from)
            print('Loading checkpoint', ckpt)
            saver.restore(sess, ckpt)

        print('Loading dataset...')
        chunks = load_dataset(args.input)
        data_sampler = Sampler(chunks)
        print('dataset has', data_sampler.total_size, 'tokens')
        print('Training...')

        counter = 1
        counter_path = os.path.join(CHECKPOINT_DIR, args.run_name, 'counter')
        if os.path.exists(counter_path):
            # Load the step number if we're resuming a run
            # Add 1 so we don't immediately try to save again
            with open(counter_path, 'r') as fp:
                counter = int(fp.read()) + 1

        def save():
            maketree(os.path.join(CHECKPOINT_DIR, args.run_name))
            print(
                'Saving',
                os.path.join(CHECKPOINT_DIR, args.run_name,
                             'model-{}').format(counter))
            saver.save(
                sess,
                os.path.join(CHECKPOINT_DIR, args.run_name, 'model'),
                global_step=counter)
            with open(counter_path, 'w') as fp:
                fp.write(str(counter) + '\n')

        def sample_batch():
            return [data_sampler.sample(1024) for _ in range(args.batch_size)]

        avg_loss = (0.0, 0.0)
        start_time = time.time()

        try:
            while True:
                if counter % args.save_every == 0:
                    save()
                (_, v_loss, v_summary) = sess.run(
                    (opt_apply, loss, summaries),
                    feed_dict={context: sample_batch()})

                summary_log.add_summary(v_summary, counter)

                avg_loss = (avg_loss[0] * 0.99 + v_loss,
                            avg_loss[1] * 0.99 + 1.0)

                print(
                    '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f}'
                    .format(
                        counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1]))

                counter += 1
        except KeyboardInterrupt:
            print('interrupted')
            save()
Example #16
from src.utils import Timer
from src.dataset import Dataset
import copy

if __name__ == '__main__':
    dataset = cfg.DATASET
    layername = cfg.FEATURE_NAME[0]
    modelname = cfg.MODEL
    caffemodel = cfg.CAFFEMODEL
    prototxt = cfg.PROTOTXT
    gpu_id = cfg.GPU_ID

    print('model: {}, caffemodel: {}, prototxt: {}'.format(
        modelname, caffemodel, prototxt))

    net = model(modelname, caffemodel, prototxt, layername, gpu_id)

    dataset = Dataset(dataset)

    image_iters = dataset.get_image_list()

    #image_iters = image_iters[0:100]

    for idx, image_iter in enumerate(image_iters):
        image_name = image_iter.split('.')[0]

        image_id = dataset.get_image_id(image_name)

        if image_id:
            image_path = os.path.join(dataset.data_path, image_name + '.jpg')
            image = cv2.imread(image_path)
Example #17
    if args.input_profile is not None:
        try:
            input_profile = pd.read_csv(args.input_profile)
            logger.info('Input profile data loaded from %s', args.input_profile)
        except FileNotFoundError:
            logger.error('Input profile file %s not found', args.input_profile)
            raise SystemExit("Provide a valid path!")

    if args.step == 'upload':
        upload_to_s3(args.lfp, **config['s3_upload'])
    if args.step == 'download':
        download_from_s3(args.lfp, **config['s3_download'])
    elif args.step == 'clean':
        output = clean_base(input, **config['clean'])
    elif args.step == 'featurize':
        output = featurize(input, input_profile, args.lfp,
                           **config['featurize'])
    elif args.step == 'model':
        model(input, args.lfp, **config['model'])
    elif args.step == 'score':
        output = scoring(args.lfp, **config['score'])
    elif args.step == 'create_db':
        if args.truncate:
            create_score_db(input, 1, **config['database'])
        else:
            create_score_db(input, 0, **config['database'])

    if args.output is not None and output is not None:
        output.to_csv(args.output, index=False)
        logger.info("Output saved to %s" % args.output)
Example #18
File: train.py Project: ianrowan/DeepWave
def train(sess,
          data,
          labels,
          steps,
          run_name,
          batch_size=1,
          n_heads=None,
          n_layers=None,
          learning_rate=0.0001,
          print_each=1,
          save_every=1000,
          accumulate=5,
          use_class_entropy=False,
          model_path="checkpoint/"):

    model_path = os.path.join(model_path, run_name)

    if not os.path.exists(model_path):
        os.mkdir(model_path)

    new_run = 'counter' not in os.listdir(model_path)

    hparams = model.default_hparams()
    #Set HyperParams
    if n_layers: hparams.n_layer = n_layers
    if n_heads: hparams.n_head = n_heads
    if os.path.exists(model_path + "/hparams.json"):
        with open(os.path.join(model_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

    #Spectrogram dimensions
    d_shape = np.shape(data)
    print(d_shape)
    hparams.n_timestep = d_shape[1]
    hparams.n_freq = d_shape[2]
    hparams.n_cat = len(labels[0])

    #Create TF graph
    inp_specs = tf.placeholder(
        tf.float32, [batch_size, hparams.n_timestep, hparams.n_freq])
    logits = model.model(hparams, inp_specs, reuse=tf.AUTO_REUSE)
    #Loss tensor = Softmax cross entropy
    label_exp = tf.placeholder(tf.int8, [batch_size, hparams.n_cat])
    loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=label_exp,
                                                   logits=logits['logits']))

    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    print("Using {} Parameter Network".format(str(len(all_vars))))

    lr = tf.placeholder(tf.float32)
    if accumulate > 1:
        # Train step using AdamOptimizer with accumulating gradients
        opt = AccumulatingOptimizer(
            opt=tf.train.AdamOptimizer(learning_rate=lr), var_list=all_vars)
        opt_reset = opt.reset()
        opt_compute = opt.compute_gradients(loss)
        opt_apply = opt.apply_gradients()
    else:
        opt = tf.train.AdamOptimizer(learning_rate=lr)
        opt_grads = tf.gradients(loss, all_vars)
        opt_grads = list(zip(opt_grads, all_vars))
        opt_apply = opt.apply_gradients(opt_grads)

    #Create saveable graph and checkpoint + counter
    saver = tf.train.Saver(var_list=all_vars)
    sess.run(tf.global_variables_initializer())
    if new_run:
        saver.save(sess, model_path + "/{}.ckpt".format(run_name))
    ckpt = tf.train.latest_checkpoint(model_path)
    print('Restoring checkpoint', ckpt)
    saver.restore(sess, ckpt)

    #Training SetUp
    #Get counter
    counter = 1
    counter_path = os.path.join(model_path, 'counter')
    if os.path.exists(counter_path):
        with open(counter_path, 'r') as fp:
            counter = int(fp.read()) + 1
    counter_base = counter

    def save():
        print('Saving',
              os.path.join(model_path, 'model-{}').format(counter - 1))
        saver.save(sess,
                   os.path.join(model_path, 'model'),
                   global_step=counter - 1)
        with open(counter_path, 'w') as fp:
            fp.write(str(counter - 1) + '\n')

    def next_batch(num, data, lab):
        '''
        Return a total of `num` random samples and labels.
        '''
        idx = np.arange(0, len(data))
        np.random.shuffle(idx)
        idx = idx[:num]
        data_shuffle = [data[i] for i in idx]
        labels_shuffle = [lab[i] for i in idx]
        return np.asarray(data_shuffle), np.asarray(labels_shuffle)

    avg_loss = (0.0, 0.0)
    start_time = time.time()

    def class_entropy(y):
        y = np.sum(y, 0)
        e = sum([(i / sum(y)) * np.log(i / sum(y)) if i > 0 else 0 for i in y])

        return np.abs(1 - (-np.log(1 / len(y)) + e))

    try:
        while counter < (counter_base + steps):
            if (counter - 1) % save_every == 0 and counter > 1:
                save()

            # Get batch of specified size
            x, lab = next_batch(batch_size, data, labels)
            lrate = learning_rate * class_entropy(
                lab) if use_class_entropy else learning_rate

            if accumulate > 1:
                sess.run(opt_reset)
                # Run gradient accumulation steps
                for _ in range(accumulate):
                    sess.run(opt_compute,
                             feed_dict={
                                 inp_specs: x,
                                 label_exp: lab,
                                 "model/drop:0": 1.0
                             })
                # Apply the accumulated gradients. Assuming the gpt-2 style
                # AccumulatingOptimizer, apply_gradients() returns the averaged
                # loss, so v_loss is defined on this branch as well.
                v_loss = sess.run(opt_apply, feed_dict={lr: lrate})
            else:
                _, v_loss = sess.run((opt_apply, loss),
                                     feed_dict={
                                         inp_specs: x,
                                         label_exp: lab,
                                         lr: lrate,
                                         "model/drop:0": 1.0
                                     })

            avg_loss = (avg_loss[0] * 0.99 + v_loss, avg_loss[1] * 0.99 + 1.0)
            print(
                '[{counter} | {time:2.2f}] loss={loss:2.2f} avg={avg:2.2f} lrate={lrate}'
                .format(counter=counter,
                        time=time.time() - start_time,
                        loss=v_loss,
                        avg=avg_loss[0] / avg_loss[1],
                        lrate=str(lrate)))
            if counter % print_each == 0:
                sample = next_batch(batch_size, data, labels)
                out = sess.run(logits,
                               feed_dict={
                                   inp_specs: sample[0],
                                   "model/drop:0": 1.0
                               })
                acc = sum(
                    np.argmax(np.asarray(out['logits']), axis=1) == np.argmax(
                        sample[1], axis=1)) / batch_size
                print("[Summary Step] Accuracy {}% for {} distribution".format(
                    str(acc * 100), str(np.sum(sample[1], 0))))
                print("Class Entropy: {}".format(str(class_entropy(
                    sample[1]))))
            counter += 1
        save()

    except KeyboardInterrupt:
        print('interrupted')
        save()
Example #19
def predict(sess,
            data,
            run_name,
            batch_size,
            num_categories,
            category_names,
            model_path="checkpoint/"):

    model_path = os.path.join(model_path, run_name)

    # Load Hyperparams from model
    hparams = model.default_hparams()
    if os.path.exists(model_path + "/hparams.json"):
        with open(os.path.join(model_path, 'hparams.json')) as f:
            hparams.override_from_dict(json.load(f))

    d_shape = np.shape(data)
    print("Precicting for data: " + str(d_shape))
    hparams.n_timestep = d_shape[1]
    hparams.n_freq = d_shape[2]
    hparams.n_cat = num_categories

    # Create TF graph
    inp_specs = tf.placeholder(
        tf.float32, [batch_size, hparams.n_timestep, hparams.n_freq])
    prediction = model.model(hparams, inp_specs)

    # Get Model vars
    all_vars = [v for v in tf.trainable_variables() if 'model' in v.name]
    saver = tf.train.Saver(var_list=all_vars)
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.latest_checkpoint(model_path)
    saver.restore(sess, ckpt)

    predictions = np.zeros((len(data), num_categories))
    num_batches = int(np.ceil(len(data) / batch_size))

    for i in tqdm(range(num_batches)):
        c = batch_size

        if i * batch_size + c > len(data):
            add = (i * batch_size + c) - len(data)
            pred = sess.run(prediction,
                            feed_dict={
                                inp_specs:
                                np.concatenate((data[i * batch_size:],
                                                np.zeros(
                                                    (add, hparams.n_timestep,
                                                     hparams.n_freq)))),
                                "model/drop:0":
                                1.0
                            })['logits']
            predictions[i * batch_size:] = pred[:-add]
        else:

            predictions[i*batch_size: i*batch_size+c] =\
                sess.run(prediction, feed_dict={inp_specs: data[i*batch_size: i*batch_size+batch_size],
                                                "model/drop:0": 1.0})['logits']

    cats = np.argmax(predictions, axis=1)

    return {
        "raw": predictions,
        "category": cats,
        "predictName": ["N", "S", "V", "F", "Q"],
        "names": category_names
    }
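The final-batch handling above pads the data with zeros so the graph always sees a fixed batch size, then drops the padded rows from the result. A small self-contained illustration of that pattern, independent of the TensorFlow graph:

import numpy as np

def batched(data, batch_size):
    # Yield (batch, pad) pairs where every batch has exactly `batch_size` rows;
    # `pad` is how many trailing rows are zero padding to be discarded.
    for start in range(0, len(data), batch_size):
        chunk = data[start:start + batch_size]
        pad = batch_size - len(chunk)
        if pad:
            chunk = np.concatenate([chunk, np.zeros((pad,) + chunk.shape[1:])])
        yield chunk, pad

data = np.random.rand(10, 3)
outputs = []
for batch, pad in batched(data, 4):
    out = batch * 2.0                      # stand-in for sess.run(prediction, ...)
    outputs.append(out[:len(out) - pad])   # drop the padded rows
result = np.concatenate(outputs)           # shape (10, 3), aligned with `data`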
Example #20
def test(model_name, test_audio_name):

    csv_test_audio = csv_dir + test_audio_name + '/'

    init, net1_optim, net2_optim, all_optim, x, x_face_id, y_landmark, y_phoneme, y_lipS, y_maya_param, dropout, cost, \
    tensorboard_op, pred, clear_op, inc_op, avg, batch_size_placeholder, phase = model()

    # start tf graph
    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    max_to_keep = 20
    saver = tf.train.Saver(max_to_keep=max_to_keep)


    try_mkdir(pred_dir)

    # Test sess, load ckpt
    OLD_CHECKPOINT_FILE = model_dir + model_name + '/' + model_name +'.ckpt'

    saver.restore(sess, OLD_CHECKPOINT_FILE)
    print("Model loaded: " + model_dir + model_name)

    total_epoch_num = 1
    print(csv_test_audio)

    data_dir = {'train': {}, 'test': {}}
    data_dir['test']['wav'] = open(csv_test_audio + "test/wav.csv", 'r')
    data_dir['test']['clip_len'] = open(csv_test_audio + "test/clip_len.csv", 'r')
    cv_file_len = simple_read_clip_len(data_dir['test']['clip_len'])
    print('Loading wav_raw.csv file in {:}'.format(csv_test_audio))

    train_wav_raw = np.loadtxt(csv_test_audio + 'wav_raw.csv')
    test_wav_raw = train_wav_raw


    for epoch in range(0, total_epoch_num):
        # clear data file header

        # ============================== TRAIN SET CHUNK ITERATION ============================== #

        sess.run(clear_op)
        for key in ['train', 'test']:
            for lpw_key in data_dir[key].keys():
                data_dir[key][lpw_key].seek(0)

        print("===================== TEST/CV CHUNK - {:} ======================".format(csv_test_audio))
        eof = False
        chunk_num = 0
        chunk_size_sum = 0

        batch_size = test_wav_raw.shape[0]
        chunk_size = batch_size * batch_per_chunk_size

        while (not eof):
            cv_data, eof = read_chunk_data(data_dir, 'test', chunk_size)
            chunk_num += 1
            chunk_size_sum += len(cv_data['wav'])

            print('Load Chunk {:d}, size {:d}, total_size {:d} ({:2.2f})'
                  .format(chunk_num, len(cv_data['wav']), chunk_size_sum, chunk_size_sum / cv_file_len))

            full_idx_array = np.arange(len(cv_data['wav']))
            # np.random.shuffle(full_idx_array)
            for next_idx in range(0, int(np.floor(len(cv_data['wav']) / batch_size))):
                batch_idx_array = full_idx_array[next_idx * batch_size: (next_idx + 1) * batch_size]
                batch_x, batch_x_face_id, batch_x_pose, batch_y_landmark, batch_y_phoneme, batch_y_lipS, batch_y_maya_param = \
                    read_next_batch_easy_from_raw(test_wav_raw, cv_data, 'face_close', batch_idx_array, batch_size, n_steps, n_input, n_landmark,
                                         n_phoneme, n_face_id)
                npClose = np.loadtxt(lpw_dir + 'saved_param/maya_close_face.txt')
                batch_x_face_id = np.tile(npClose, (batch_x_face_id.shape[0], 1))


                test_pred, loss, _ = sess.run([pred, cost, inc_op],
                                            feed_dict={x: batch_x,
                                                       x_face_id: batch_x_face_id,
                                                       y_landmark: batch_y_landmark,
                                                       y_phoneme: batch_y_phoneme,
                                                       y_lipS: batch_y_lipS,
                                                       dropout: 0,
                                                       batch_size_placeholder: batch_x.shape[0],
                                                       phase: 0,
                                                       y_maya_param: batch_y_maya_param})


                def save_output(filename, npTxt, fmt):
                    f = open(filename, 'wb')
                    np.savetxt(f, npTxt, fmt=fmt)
                    f.close()

                try_mkdir(pred_dir + test_audio_name)

                def sigmoid(x):
                    return 1/(1+np.exp(-x))
                save_output(pred_dir + test_audio_name + "/mayaparam_pred_cls.txt",
                            np.concatenate([test_pred['jali'], sigmoid(test_pred['v_cls'])], axis=1), '%.4f')
                save_output(pred_dir + test_audio_name + "/mayaparam_pred_reg.txt",
                            np.concatenate([test_pred['jali'], test_pred['v_reg']], axis=1), '%.4f')
Example #21
def evaluate(model,
             dataloader,
             n_ary,
             decode_level,
             criteriaOfEachLevel,
             printResult=False):
    def outputAccuracyEachLevel(accuracyEachLevel):
        out = []
        for level, acc in enumerate(accuracyEachLevel):
            out.append("Level {}: {}, ".format(level, acc))
        return ' '.join(out)

    model.eval()
    loss = 0
    correctEachLevel = [0] * decode_level
    countEachLevel = [0] * decode_level

    with torch.no_grad():
        for batch in tqdm(dataloader):
            for sample in batch:
                # source to reference
                output = model(sample["source"]["syntax"],
                               sample["reference"]["syntax"])
                output = output.view(-1, output.size(2))
                target = sample["reference"]["label"].to(cuda)
                loss += computeLossByLevel(output, target,
                                           sample["reference"]["level"],
                                           criteriaOfEachLevel)
                accumulateAccuracyEachLayer(
                    correctEachLevel=correctEachLevel,
                    countEachLevel=countEachLevel,
                    predictIndexes=output.argmax(1),
                    target=target,
                    indexesEachLevel=sample["reference"]["level"],
                    n_ary=n_ary)
                if printResult:
                    print(
                        "Target:",
                        indexesToSymbols(target, total_symbols,
                                         sample["reference"]["level"]))
                    print(
                        "Predict:",
                        indexesToSymbols(output.argmax(1), total_symbols,
                                         sample["reference"]["level"]))
                    print()
                # reference to source
                output = model(sample["reference"]["syntax"],
                               sample["source"]["syntax"])
                output = output.view(-1, output.size(2))
                target = sample["source"]["label"].to(cuda)
                loss += computeLossByLevel(output, target,
                                           sample["source"]["level"],
                                           criteriaOfEachLevel)
                accumulateAccuracyEachLayer(
                    correctEachLevel=correctEachLevel,
                    countEachLevel=countEachLevel,
                    predictIndexes=output.argmax(1),
                    target=target,
                    indexesEachLevel=sample["source"]["level"],
                    n_ary=n_ary)
                if printResult:
                    print(
                        "Target:",
                        indexesToSymbols(target, total_symbols,
                                         sample["source"]["level"]))
                    print(
                        "Predict:",
                        indexesToSymbols(output.argmax(1), total_symbols,
                                         sample["source"]["level"]))
                    print()

    loss = loss / (2 * len(dataloader) * dataloader.batch_size * decode_level)

    accuracyEachLevel = []
    for correct, count in zip(correctEachLevel, countEachLevel):
        accuracyEachLevel.append(correct / count)

    print("Evaluation Results. loss: {:.5f} accuracy at each level: {}".format(
        loss, outputAccuracyEachLevel(accuracyEachLevel)))
    model.train()
    acc = torch.tensor(correctEachLevel).sum() / torch.tensor(
        countEachLevel).sum()
    return acc
Example #22
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device_str = "cuda:0" if torch.cuda.is_available() else "cpu"
    print("Device: %s\n" % device_str)
    device = torch.device(device_str)

    # Hyperparameter for Cutmix
    cutmix_beta = 0.3

    # Hyperparameter
    epochs = 100
    lr = 0.01

    train_loader, valid_loader = data.load_data(batch_size=64)
    print("Train samples: %d" % len(train_loader.dataset))
    print("Valid samples: %d" % len(valid_loader.dataset))
    model = model.model()
    model = model.to(device)

    criterion_lss1 = nn.BCELoss()
    criterion_lss2 = nn.KLDivLoss(reduction='batchmean')
    criterion_ce = nn.CrossEntropyLoss()

    optimizer = optim.Adam(model.parameters(), lr=lr)

    time_str = time.strftime("%m_%d-%Hh%Mm%Ss", time.localtime())
    file = open("../log/%s.csv" % time_str, 'w')
    writer = csv.writer(file)
    headers = [
        "train_loss", "train_acc", "train_lsl", "train_lss_1", "train_lss_2",
        "train_lsd", "valid_loss", "valid_acc", "valid_lsl", "valid_lss_1",
        "valid_lss_2", "valid_lsd"
Example #23
epoch_range = trange(epoch, desc='Loss: {1.00000}', leave=True)

model.train().to(device)
# model.load_state_dict(torch.load('modelfiles/yeet.mdl'))

optimizer = torch.optim.Adam(model.parameters(), lr=lr)
norm = t.normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
for e in epoch_range:
    epoch_loss = []
    for image, data_dict in data:
        # c = torch.zeros((1, image.shape[1], image.shape[2]), device=device)
        # image = torch.cat((image, image, image), dim=0)
        image = norm({'image': image})['image']
        optimizer.zero_grad()
        loss = model(image.unsqueeze(0), [data_dict])
        losses = 0
        for key in loss:
            losses += loss[key]
        losses.backward()
        epoch_loss.append(losses.item())
        optimizer.step()

    epoch_range.desc = 'Loss: ' + '{:.5f}'.format(
        torch.tensor(epoch_loss).mean().item())

torch.save(model.state_dict(), 'modelfiles/yeet.mdl')
with torch.no_grad():
    model.eval()
    feature_extractor.to(device).eval()
    reduced_images = torch.zeros((1, 2048), device=device)