Example #1
def main(args):
    """Starting point"""
    logging.basicConfig(format='%(asctime)s -- %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO)

    img_path = args['IMAGE_PATH']
    checkpoint = args['CHECKPOINT']
    topk = int(args['--top_k'])
    device = torch.device('cuda' if args['--gpu'] else 'cpu')

    category_names = None
    if args['--category_names']:
        with open(args['--category_names'], 'r') as f:
            category_names = json.load(f)

    logging.info('Rebuilding model from checkpoint "%s"...', checkpoint)
    model = common.load_checkpoint(checkpoint)

    logging.info('Predicting top %s class(es) for "%s"', topk, img_path)
    probs, classes = predict(img_path, model, device, topk)

    probs = [p.item() for p in probs]
    classes = [cl.item() for cl in classes]
    if category_names is not None:
        idx_to_class = {v: k for k, v in model.class_to_idx.items()}
        classes = [category_names[idx_to_class[cl]] for cl in classes]
    logging.info('Predicted classes: %s', classes)
    logging.info('Predicted class probabilities: %s', probs)

    logging.info('My job is done, going gently into that good night!')
    return 0
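The example above relies on a predict() helper defined elsewhere in the project. A minimal sketch of what such a helper could look like, assuming torchvision-style preprocessing and a classifier that returns logits (not the original implementation):

# Hypothetical sketch of predict(), not part of the original example.
import torch
from PIL import Image
from torchvision import transforms

def predict(img_path, model, device, topk=5):
    """Return top-k probabilities and class indices for a single image."""
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    img = preprocess(Image.open(img_path).convert('RGB')).unsqueeze(0)

    model.to(device)
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(img.to(device)), dim=1)
    top_probs, top_indices = probs.topk(topk, dim=1)
    return top_probs.squeeze(0), top_indices.squeeze(0)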
Example #2
def main():
    args = get_args()
    setup_logger('{}/log-train'.format(args.dir), args.log_level)
    logging.info(' '.join(sys.argv))

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    # WARNING(fangjun): we have to select GPU at the very
    # beginning; otherwise you will get trouble later
    kaldi.SelectGpuDevice(device_id=args.device_id)
    kaldi.CuDeviceAllowMultithreading()
    device = torch.device('cuda', args.device_id)

    den_fst = fst.StdVectorFst.Read(args.den_fst_filename)

    # TODO(fangjun): pass these options from commandline
    opts = chain.ChainTrainingOptions()
    opts.l2_regularize = 5e-4
    opts.leaky_hmm_coefficient = 0.1

    den_graph = chain.DenominatorGraph(fst=den_fst, num_pdfs=args.output_dim)

    model = get_chain_model(feat_dim=args.feat_dim,
                            output_dim=args.output_dim,
                            lda_mat_filename=args.lda_mat_filename,
                            hidden_dim=args.hidden_dim,
                            kernel_size_list=args.kernel_size_list,
                            stride_list=args.stride_list)

    start_epoch = 0
    num_epochs = args.num_epochs
    learning_rate = args.learning_rate
    best_objf = -100000

    if args.checkpoint:
        start_epoch, learning_rate, best_objf = load_checkpoint(
            args.checkpoint, model)
        logging.info(
            'loaded from checkpoint: start epoch {start_epoch}, '
            'learning rate {learning_rate}, best objf {best_objf}'.format(
                start_epoch=start_epoch,
                learning_rate=learning_rate,
                best_objf=best_objf))

    model.to(device)

    dataloader = get_egs_dataloader(egs_dir=args.cegs_dir,
                                    egs_left_context=args.egs_left_context,
                                    egs_right_context=args.egs_right_context)

    optimizer = optim.Adam(model.parameters(),
                           lr=learning_rate,
                           weight_decay=args.l2_regularize)

    scheduler = MultiStepLR(optimizer, milestones=[1, 2, 3, 4, 5], gamma=0.5)
    criterion = KaldiChainObjfFunction.apply

    tf_writer = SummaryWriter(log_dir='{}/tensorboard'.format(args.dir))

    best_epoch = start_epoch
    best_model_path = os.path.join(args.dir, 'best_model.pt')
    best_epoch_info_filename = os.path.join(args.dir, 'best-epoch-info')
    try:
        for epoch in range(start_epoch, args.num_epochs):
            learning_rate = scheduler.get_lr()[0]
            logging.info('epoch {}, learning rate {}'.format(
                epoch, learning_rate))
            tf_writer.add_scalar('learning_rate', learning_rate, epoch)

            objf = train_one_epoch(dataloader=dataloader,
                                   model=model,
                                   device=device,
                                   optimizer=optimizer,
                                   criterion=criterion,
                                   current_epoch=epoch,
                                   opts=opts,
                                   den_graph=den_graph,
                                   tf_writer=tf_writer)
            scheduler.step()

            if best_objf is None:
                best_objf = objf
                best_epoch = epoch

            # the higher, the better
            if objf > best_objf:
                best_objf = objf
                best_epoch = epoch
                save_checkpoint(filename=best_model_path,
                                model=model,
                                epoch=epoch,
                                learning_rate=learning_rate,
                                objf=objf)
                save_training_info(filename=best_epoch_info_filename,
                                   model_path=best_model_path,
                                   current_epoch=epoch,
                                   learning_rate=learning_rate,
                                   objf=best_objf,
                                   best_objf=best_objf,
                                   best_epoch=best_epoch)

            # we always save the model for every epoch
            model_path = os.path.join(args.dir, 'epoch-{}.pt'.format(epoch))
            save_checkpoint(filename=model_path,
                            model=model,
                            epoch=epoch,
                            learning_rate=learning_rate,
                            objf=objf)

            epoch_info_filename = os.path.join(args.dir,
                                               'epoch-{}-info'.format(epoch))
            save_training_info(filename=epoch_info_filename,
                               model_path=model_path,
                               current_epoch=epoch,
                               learning_rate=learning_rate,
                               objf=objf,
                               best_objf=best_objf,
                               best_epoch=best_epoch)

    except KeyboardInterrupt:
        # save the model when ctrl-c is pressed
        model_path = os.path.join(args.dir,
                                  'epoch-{}-interrupted.pt'.format(epoch))
        # use a very small objf for interrupted model
        objf = -100000
        save_checkpoint(model_path,
                        model=model,
                        epoch=epoch,
                        learning_rate=learning_rate,
                        objf=objf)

        epoch_info_filename = os.path.join(
            args.dir, 'epoch-{}-interrupted-info'.format(epoch))
        save_training_info(filename=epoch_info_filename,
                           model_path=model_path,
                           current_epoch=epoch,
                           learning_rate=learning_rate,
                           objf=objf,
                           best_objf=best_objf,
                           best_epoch=best_epoch)

    tf_writer.close()
    logging.warning('Done')
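The training loop above depends on save_checkpoint() and load_checkpoint() helpers from the surrounding project. A rough sketch of what they might look like, assuming the checkpoint is a plain dict written with torch.save (the real helpers may store more state):

# Hypothetical checkpoint helpers matching the calls in the loop above.
import torch

def save_checkpoint(filename, model, epoch, learning_rate, objf):
    # Persist the weights plus the metadata that load_checkpoint() hands back.
    torch.save({
        'state_dict': model.state_dict(),
        'epoch': epoch,
        'learning_rate': learning_rate,
        'objf': objf,
    }, filename)

def load_checkpoint(filename, model):
    checkpoint = torch.load(filename, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    # Matches the (start_epoch, learning_rate, best_objf) unpacking above.
    return checkpoint['epoch'], checkpoint['learning_rate'], checkpoint['objf']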
Example #3
def train(trace_length, render_eval=False, h_size=512, action_h_size=512,
          target_update_freq=10000, ckpt_freq=500000, summary_freq=1000, eval_freq=10000,
          batch_size=32, env_name='SpaceInvaders', total_iteration=5e7,
          pretrain_steps=50000):

    pretrain_steps = 1000
    # env_name += 'NoFrameskip-v4'
    identity = 'stack={},env={},mod={}'.format(trace_length, env_name, 'adrqn')
    if action_h_size != 512:
        identity += ',action_h={}'.format(action_h_size)

    env = Env(env_name=env_name, skip=4)
    a_size = env.n_actions

    tf.reset_default_graph()
    cell = tf.nn.rnn_cell.LSTMCell(num_units=h_size)
    cellT = tf.nn.rnn_cell.LSTMCell(num_units=h_size)
    mainQN = Qnetwork(h_size, a_size, action_h_size, cell, 'main')
    targetQN = Qnetwork(h_size, a_size, action_h_size, cellT, 'target')
    init = tf.global_variables_initializer()
    updateOps = util.getTargetUpdateOps(tf.trainable_variables())
    saver = tf.train.Saver(max_to_keep=5)

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(init)
    summary_writer = tf.summary.FileWriter('./log/' + identity, sess.graph)

    if util.checkpoint_exists(identity):
        (exp_buf, env, last_iteration, is_done,
         prev_life_count, action, state, S) = util.load_checkpoint(sess, saver, identity)
        if not isinstance(exp_buf, FixedActionTraceBuf):
            new_expbuf = FixedActionTraceBuf(trace_length, buf_length=500000)
            new_expbuf.load_from_legacy(exp_buf)
            del exp_buf
            exp_buf = new_expbuf
        start_time = time.time()
    else:
        exp_buf = FixedActionTraceBuf(trace_length, scenario_size=2500)
        last_iteration = 1 - pretrain_steps
        is_done = True
        action = 0
        prev_life_count = None
        state = None
        sess.run(updateOps)

    summaryOps = tf.summary.merge_all()

    eval_summary_ph = tf.placeholder(tf.float32, shape=(2,), name='evaluation')
    evalOps = (tf.summary.scalar('performance', eval_summary_ph[0]),
               tf.summary.scalar('perform_std', eval_summary_ph[1]))
    online_summary_ph = tf.placeholder(tf.float32, shape=(2,), name='online')
    onlineOps = (tf.summary.scalar('online_performance', online_summary_ph[0]),
                 tf.summary.scalar('online_scenario_length', online_summary_ph[1]))

    for i in range(last_iteration, int(total_iteration)):
        if is_done:
            scen_R, scen_len = exp_buf.flush_scenario()
            if i > 0:
                online_perf_and_length = np.array([scen_R, scen_len])
                online_perf, online_episode_count = sess.run(onlineOps, feed_dict={
                    online_summary_ph: online_perf_and_length})
                summary_writer.add_summary(online_perf, i)
                summary_writer.add_summary(online_episode_count, i)

            s, r, prev_life_count = env.reset()
            S = [s]
            action, state = mainQN.get_action_and_next_state(sess, None, [action], S)

        S = [S[-1]]
        for _ in range(4):
            s, r, is_done, life_count = env.step(action)
            exp_buf.append_trans((
                S[-1], action, r, s,  # not clipping reward (huber loss)
                (prev_life_count and life_count < prev_life_count or is_done)
            ))
            S.append(s)
            prev_life_count = life_count

        action, state = mainQN.get_action_and_next_state(sess, state, [action]*len(S), S)
        if np.random.random() < util.epsilon_at(i):
            action = env.rand_action()

        if not i:
            start_time = time.time()

        if util.Exiting or not i % ckpt_freq:
            util.checkpoint(sess, saver, identity,
                            exp_buf, env, i, is_done,
                            prev_life_count, action, state, S)
            if util.Exiting:
                raise SystemExit

        if i <= 0:
            continue

        if not i % target_update_freq:
            sess.run(updateOps)
            cur_time = time.time()
            print('[{}{}:{}] took {} seconds to {} steps'.format(
                'ADRQN', trace_length, i, cur_time-start_time, target_update_freq), flush=1)
            start_time = cur_time

        # TRAINING STARTS
        state_train = (np.zeros((batch_size, h_size)),) * 2

        trainBatch = exp_buf.sample_traces(batch_size)

        Q1 = sess.run(mainQN.predict, feed_dict={
            mainQN.scalarInput: np.vstack(trainBatch[:, 3]),
            mainQN.actionsInput: trainBatch[:, 1],
            mainQN.trainLength: trace_length,
            mainQN.state_init: state_train,
            mainQN.batch_size: batch_size
        })
        Q2 = sess.run(targetQN.Qout, feed_dict={
            targetQN.scalarInput: np.vstack(trainBatch[:, 3]),
            targetQN.actionsInput: trainBatch[:, 1],
            targetQN.trainLength: trace_length,
            targetQN.state_init: state_train,
            targetQN.batch_size: batch_size
        })
        # end_multiplier = - (trainBatch[:, 4] - 1)
        # doubleQ = Q2[range(batch_size * trace_length), Q1]
        # targetQ = trainBatch[:, 2] + (0.99 * doubleQ * end_multiplier)

        _, summary = sess.run((mainQN.updateModel, summaryOps), feed_dict={
            mainQN.scalarInput: np.vstack(trainBatch[:, 0]),
            mainQN.actionsInput: trainBatch[:, 5],
            mainQN.sample_rewards: trainBatch[:, 2],
            mainQN.sample_terminals: trainBatch[:, 4],
            mainQN.doubleQ: Q2[range(batch_size * trace_length), Q1],
            # mainQN.targetQ: targetQ,
            mainQN.actions: trainBatch[:, 1],
            mainQN.trainLength: trace_length,
            mainQN.state_init: state_train,
            mainQN.batch_size: batch_size
        })

        if not i % summary_freq:
            summary_writer.add_summary(summary, i)
        if not i % eval_freq:
            eval_res = np.array(
                evaluate(sess, mainQN, env_name, is_render=render_eval))
            perf, perf_std = sess.run(
                evalOps, feed_dict={eval_summary_ph: eval_res})
            summary_writer.add_summary(perf, i)
            summary_writer.add_summary(perf_std, i)
    # In the end: checkpoint before closing the session it needs
    util.checkpoint(sess, saver, identity)
    sess.close()
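util.getTargetUpdateOps() is not shown. A common TF1 implementation, sketched here under the assumption that tf.trainable_variables() lists the 'main' network variables first and the 'target' network variables second, in matching order:

# Hypothetical sketch of util.getTargetUpdateOps(): the standard DQN target copy.
import tensorflow as tf

def getTargetUpdateOps(variables):
    # Assumes the first half of the variables belongs to the 'main' network
    # and the second half to the 'target' network, created in the same order.
    half = len(variables) // 2
    return [target.assign(main)
            for main, target in zip(variables[:half], variables[half:])]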
Example #4
File: dqn.py  Project: lyu-xg/drqn
def train(stack_length=4,
          render_eval=False,
          h_size=512,
          target_update_freq=10000,
          ckpt_freq=500000,
          summary_freq=1000,
          eval_freq=10000,
          batch_size=32,
          env_name='Pong',
          total_iteration=5e7,
          discovery=False,
          pretrain_steps=50000):
    absolute_start_time = time.time()
    # KICKSTART_EXP_BUF_FILE = 'cache/stack_buf_random_policy_{}.p'.format(pretrain_steps)
    identity = 'stack={},env={},mod={}'.format(stack_length, env_name, 'dqn')

    env = Env(env_name=env_name, skip=4)

    tf.reset_default_graph()

    # loads of side effect! e.g. initialize session, creating graph, etc.
    mainQN = Qnetwork(h_size,
                      env.n_actions,
                      stack_length,
                      'main',
                      train_batch_size=batch_size)

    saver = tf.train.Saver(max_to_keep=5)
    summary_writer = tf.summary.FileWriter('./log/' + identity,
                                           mainQN.sess.graph)

    if util.checkpoint_exists(identity):
        (exp_buf, env, last_iteration, is_done, prev_life_count, action,
         frame_buf) = util.load_checkpoint(mainQN.sess, saver, identity)
        start_time = util.time()
    else:
        frame_buf = FrameBuf(size=stack_length)
        exp_buf, last_iteration = ((StackBuf(size=util.MILLION),
                                    1 - pretrain_steps))
        # if not os.path.isfile(KICKSTART_EXP_BUF_FILE)
        # else (util.load(KICKSTART_EXP_BUF_FILE), 1))
        is_done = True
        prev_life_count = None
        mainQN.update_target_network()

    summaryOps = tf.summary.merge_all()

    eval_summary_ph = tf.placeholder(tf.float32,
                                     shape=(2, ),
                                     name='evaluation')
    evalOps = (tf.summary.scalar('performance', eval_summary_ph[0]),
               tf.summary.scalar('perform_std', eval_summary_ph[1]))
    online_summary_ph = tf.placeholder(tf.float32, shape=(2, ), name='online')
    onlineOps = (tf.summary.scalar('online_performance', online_summary_ph[0]),
                 tf.summary.scalar('online_scenario_length',
                                   online_summary_ph[1]))

    # Main Loop
    for i in range(last_iteration, int(total_iteration)):
        if is_done:
            scen_reward, scen_length = exp_buf.get_and_reset_reward_and_length()
            if i > 0:
                online_perf, online_episode_count = mainQN.sess.run(
                    onlineOps,
                    feed_dict={
                        online_summary_ph: np.array([scen_reward, scen_length])
                    })
                summary_writer.add_summary(online_perf, i)
                summary_writer.add_summary(online_episode_count, i)

            _, prev_life_count = reset(stack_length, env, frame_buf)
            action = mainQN.get_action(list(frame_buf))

        s, r, is_done, life_count = env.step(action,
                                             epsilon=util.epsilon_at(i))
        exp_buf.append_trans((
            list(frame_buf),
            action,
            r,
            list(frame_buf.append(s)),  # not clipping reward (huber loss)
            (prev_life_count and life_count < prev_life_count or is_done)))
        prev_life_count = life_count
        action = mainQN.get_action(list(frame_buf))

        if not i:
            start_time = util.time()
            # util.pickle.dump(exp_buf, open(KICKSTART_EXP_BUF_FILE, 'wb'))

        if i <= 0: continue

        if discovery and time.time() - absolute_start_time > 85500:
            # 23 hours and 45 minutes
            util.Exiting = 1

        if util.Exiting or not i % ckpt_freq:
            util.checkpoint(mainQN.sess, saver, identity, exp_buf, env, i,
                            is_done, prev_life_count, action, frame_buf)
            if util.Exiting:
                raise SystemExit

        if not i % target_update_freq:
            mainQN.update_target_network()
            cur_time = util.time()
            print('[{}{}:{}] took {} seconds to {} steps'.format(
                'dqn', stack_length, util.unit_convert(i),
                (cur_time - start_time) // 1, target_update_freq),
                  flush=1)
            start_time = cur_time

        # TRAIN
        trainBatch = exp_buf.sample_batch(batch_size)

        _, summary = mainQN.update_model(*trainBatch,
                                         additional_ops=[summaryOps])

        if not i % summary_freq:
            summary_writer.add_summary(summary, i)
        if not i % eval_freq:
            eval_res = np.array(
                evaluate(mainQN, env_name, is_render=render_eval))
            perf, perf_std = mainQN.sess.run(
                evalOps, feed_dict={eval_summary_ph: eval_res})
            summary_writer.add_summary(perf, i)
            summary_writer.add_summary(perf_std, i)
    # In the end
    util.checkpoint(mainQN.sess, saver, identity)
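The FrameBuf used above is defined elsewhere. A minimal sketch that supports the list(frame_buf) and list(frame_buf.append(s)) idioms in the loop, assuming it simply keeps the most recent stack_length frames:

# Hypothetical sketch of FrameBuf: a fixed-length buffer of the latest frames.
from collections import deque

class FrameBuf:
    def __init__(self, size=4):
        self.buf = deque(maxlen=size)

    def append(self, frame):
        # The oldest frame is dropped automatically once the buffer is full;
        # returning self lets the caller write list(frame_buf.append(s)).
        self.buf.append(frame)
        return self

    def __iter__(self):
        return iter(self.buf)

    def __len__(self):
        return len(self.buf)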
Example #5
def train(trace_length=10,
          render_eval=False,
          h_size=512,
          target_update_freq=10000,
          ckpt_freq=500000,
          summary_freq=1000,
          eval_freq=10000,
          discovery=False,
          batch_size=32,
          env_name='Pong',
          total_iteration=5e7,
          use_actions=0,
          pretrain_steps=50000,
          num_quant=0):
    # network = dist_Qnetwork if num_quant else Qnetwork
    # env_name += 'NoFrameskip-v4'
    absolute_start_time = time.time()
    model = 'drqn' if not use_actions else 'adrqn'
    if num_quant:
        model = 'dist-' + model
    KICKSTART_EXP_BUF_FILE = 'cache/exp_buf_random_policy_{}_{}_{}.p'.format(
        model, env_name, pretrain_steps)
    ExpBuf = FixedTraceBuf if not use_actions else FixedActionTraceBuf

    model_args = {}
    identity = 'stack={},env={},mod={},h_size={}'.format(
        trace_length, env_name, model, h_size)
    if num_quant:
        identity += ',quantile={}'.format(num_quant)
    if use_actions:
        identity += ',action_dim={}'.format(use_actions)
        model_args['action_hidden_size'] = use_actions
    print(identity)

    env = Env(env_name=env_name, skip=4)

    mainQN = Qnetwork(h_size,
                      env.n_actions,
                      1,
                      'main',
                      train_batch_size=batch_size,
                      model=model,
                      train_trace_length=trace_length,
                      model_kwargs=model_args,
                      num_quant=num_quant)
    saver = tf.train.Saver(max_to_keep=5)

    summary_writer = tf.summary.FileWriter('./log/' + identity,
                                           mainQN.sess.graph)

    if util.checkpoint_exists(identity):
        (exp_buf, env, last_iteration, is_done, prev_life_count, action,
         mainQN.hidden_state,
         S) = util.load_checkpoint(mainQN.sess, saver, identity)
        start_time = time.time()
    else:
        exp_buf = ExpBuf(trace_length, buf_length=500000)
        last_iteration = 1 - pretrain_steps
        if os.path.isfile(KICKSTART_EXP_BUF_FILE):
            print('Filling buffer with random episodes on disk.')
            exp_buf, last_iteration = util.load(KICKSTART_EXP_BUF_FILE), 1
        is_done = True
        prev_life_count = None
        mainQN.update_target_network()

    summaryOps = tf.summary.merge_all()

    eval_summary_ph = tf.placeholder(tf.float32,
                                     shape=(4, ),
                                     name='evaluation')
    evalOps = (tf.summary.scalar('performance', eval_summary_ph[0]),
               tf.summary.scalar('perform_std', eval_summary_ph[1]),
               tf.summary.scalar('flicker_performance', eval_summary_ph[2]),
               tf.summary.scalar('flicker_perform_std', eval_summary_ph[3]))
    online_summary_ph = tf.placeholder(tf.float32, shape=(2, ), name='online')
    onlineOps = (tf.summary.scalar('online_performance', online_summary_ph[0]),
                 tf.summary.scalar('online_scenario_length',
                                   online_summary_ph[1]))

    for i in range(last_iteration, int(total_iteration)):
        if is_done:
            scen_R, scen_L = exp_buf.flush_scenario()
            if i > 0:
                online_perf_and_length = np.array([scen_R, scen_L])
                online_perf, online_episode_count = mainQN.sess.run(
                    onlineOps,
                    feed_dict={online_summary_ph: online_perf_and_length})
                summary_writer.add_summary(online_perf, i)
                summary_writer.add_summary(online_episode_count, i)

            S, r, prev_life_count = env.reset()
            S = np.reshape(S, (1, 84, 84))
            mainQN.reset_hidden_state()

        action, _ = mainQN.get_action_stateful(S, prev_a=0)
        S_new, r, is_done, life_count = env.step(action,
                                                 epsilon=util.epsilon_at(i))
        S_new = np.reshape(S_new, (1, 84, 84))
        exp_buf.append_trans((
            S,
            action,
            r,
            S_new,  # not clipping reward (huber loss)
            (prev_life_count and life_count < prev_life_count or is_done)))
        S = S_new
        prev_life_count = life_count

        if not i:
            util.save(exp_buf, KICKSTART_EXP_BUF_FILE)

        if discovery and time.time() - absolute_start_time > 85500:
            # 23 hours and 45 minutes
            util.Exiting = 1

        if util.Exiting or not i % ckpt_freq:
            util.checkpoint(mainQN.sess, saver, identity, exp_buf, env, i,
                            is_done, prev_life_count, action,
                            mainQN.hidden_state, S)
            if util.Exiting:
                raise SystemExit

        if i < 0: continue

        # TRAIN
        _, summary = mainQN.update_model_stateful(
            *exp_buf.sample_traces(batch_size), addtional_ops=[summaryOps])

        # Summary
        if not i % summary_freq:
            summary_writer.add_summary(summary, i)

        # Target Update
        if not i % target_update_freq:
            mainQN.update_target_network()
            cur_time = time.time()
            print(i, identity)
            try:
                print('[{}:{}:{}K] took {} seconds to {} steps'.format(
                    model, env_name, util.unit_convert(i),
                    int(cur_time - start_time), target_update_freq),
                      flush=1)
            except:
                pass
            start_time = cur_time

        # Evaluate
        if not i % eval_freq:
            eval_res = np.array(
                evaluate(mainQN, env_name, is_render=render_eval))
            eval_vals = mainQN.sess.run(evalOps,
                                        feed_dict={eval_summary_ph: eval_res})
            for v in eval_vals:
                summary_writer.add_summary(v, i)

    util.checkpoint(mainQN.sess, saver, identity)
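All three RL examples anneal exploration through util.epsilon_at(i), whose implementation is not shown. A plausible sketch, assuming a standard linear schedule with full exploration during the negative (pretraining) iterations:

# Hypothetical sketch of util.epsilon_at(): linearly annealed epsilon-greedy rate.
def epsilon_at(i, start=1.0, end=0.1, anneal_steps=1000000):
    if i <= 0:
        return start  # act randomly while the replay buffer is being filled
    frac = min(float(i) / anneal_steps, 1.0)
    return start + frac * (end - start)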
Example #6
def main():
    args = get_args()
    setup_logger('{}/log-inference'.format(args.dir), args.log_level)
    logging.info(' '.join(sys.argv))

    if not torch.cuda.is_available():
        logging.warning('No GPU detected! Use CPU for inference.')
        device = torch.device('cpu')
    else:
        device = torch.device('cuda', args.device_id)

    model = get_chain_model(feat_dim=args.feat_dim,
                            output_dim=args.output_dim,
                            lda_mat_filename=args.lda_mat_filename,
                            hidden_dim=args.hidden_dim,
                            kernel_size_list=args.kernel_size_list,
                            stride_list=args.stride_list)

    load_checkpoint(args.checkpoint, model)

    model.to(device)
    model.eval()

    specifier = 'ark,scp:{filename}.ark,{filename}.scp'.format(
        filename=os.path.join(args.dir, 'nnet_output'))

    if args.save_as_compressed:
        Writer = kaldi.CompressedMatrixWriter
        Matrix = kaldi.CompressedMatrix
    else:
        Writer = kaldi.MatrixWriter
        Matrix = kaldi.FloatMatrix

    writer = Writer(specifier)

    dataloader = get_feat_dataloader(
        feats_scp=args.feats_scp,
        model_left_context=args.model_left_context,
        model_right_context=args.model_right_context,
        batch_size=32)

    for batch_idx, batch in enumerate(dataloader):
        key_list, padded_feat, output_len_list = batch
        padded_feat = padded_feat.to(device)
        with torch.no_grad():
            nnet_output, _ = model(padded_feat)

        num = len(key_list)
        for i in range(num):
            key = key_list[i]
            output_len = output_len_list[i]
            value = nnet_output[i, :output_len, :]
            value = value.cpu()

            m = kaldi.SubMatrixFromDLPack(to_dlpack(value))
            m = Matrix(m)
            writer.Write(key, m)

        if batch_idx % 10 == 0:
            logging.info('Processed batch {}/{} ({:.6f}%)'.format(
                batch_idx, len(dataloader),
                float(batch_idx) / len(dataloader) * 100))

    writer.Close()
    logging.info('pseudo-log-likelihood is saved to {}'.format(
        os.path.join(args.dir, 'nnet_output.scp')))
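The inference loop assumes get_feat_dataloader() yields (key_list, padded_feat, output_len_list) batches. A simplified sketch of the padding/collate step, ignoring the left/right context handling the real dataloader performs:

# Hypothetical collate function producing the batch layout consumed above.
import torch
from torch.nn.utils.rnn import pad_sequence

def collate_feats(batch):
    # batch: list of (key, feat) pairs with feat of shape (num_frames, feat_dim)
    key_list = [key for key, _ in batch]
    feat_list = [feat for _, feat in batch]
    output_len_list = [feat.size(0) for feat in feat_list]
    padded_feat = pad_sequence(feat_list, batch_first=True)
    return key_list, padded_feat, output_len_list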
Example #7
category_names = results.category_names
gpu = results.gpu
top_k = int(float(top_k))


if gpu and not torch.cuda.is_available():
    print("There is no GPU device available")


message_cuda = "cuda is available" if torch.cuda.is_available() else "cuda is not available"

print(message_cuda)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = load_checkpoint(checkpoint)

model.to(device)

class_to_idx = model.class_to_idx

probs, classes = predict(path_to_image, model, top_k)

if not category_names:
    print(classes)
    print(probs)

else:
    name_classes = []
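Example #7 calls a load_checkpoint() helper (similar to the one used in Example #1) that is not shown. A hedged sketch, assuming a torchvision backbone whose classifier, weights, and class_to_idx mapping were saved under keys chosen here purely for illustration:

# Hypothetical sketch of load_checkpoint(); the dict keys are assumptions.
import torch
from torchvision import models

def load_checkpoint(path):
    checkpoint = torch.load(path, map_location='cpu')
    model = models.vgg16(pretrained=True)          # assumed backbone
    model.classifier = checkpoint['classifier']    # assumed key
    model.load_state_dict(checkpoint['state_dict'])
    model.class_to_idx = checkpoint['class_to_idx']
    return model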