Example No. 1
    def get_optimzer_lr(self, global_step, step_factor):
        gamma = cfg.TRAIN.GAMMA
        epoch_iters = get_epoch_iters(self.net.model_name)
        # Decay step size is tied to the epoch length (decay every two epochs);
        # cfg.TRAIN.STEPSIZE * step_factor is only used in the 'zsnrare' branch below.
        stepsize = epoch_iters * 2

        lr = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE * 10,
                                        global_step,
                                        stepsize,
                                        gamma,
                                        staircase=True)
        optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)

        if 'cosine' in self.net.model_name:
            print('cosine =========')
            first_decay_steps = epoch_iters * 10  # first restart after 10 epochs
            # tf.train.cosine_decay_restarts is the public TF 1.x alias for
            # learning_rate_decay.cosine_decay_restarts.
            lr = tf.train.cosine_decay_restarts(cfg.TRAIN.LEARNING_RATE * 10,
                                                global_step,
                                                first_decay_steps,
                                                t_mul=2.0,
                                                m_mul=0.9,
                                                alpha=cfg.TRAIN.LEARNING_RATE * 0.1)
            optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
        elif 'zsrare' in self.net.model_name:  # rare first
            lr = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE * 10,
                                            global_step,
                                            int(cfg.TRAIN.STEPSIZE * 2),
                                            gamma,
                                            staircase=True)
            optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
        elif 'zsnrare' in self.net.model_name:  # non-rare first
            lr = tf.train.exponential_decay(cfg.TRAIN.LEARNING_RATE * 10,
                                            global_step,
                                            int(cfg.TRAIN.STEPSIZE *
                                                step_factor),
                                            gamma,
                                            staircase=True)
            optimizer = tf.train.MomentumOptimizer(lr, cfg.TRAIN.MOMENTUM)
        return lr, optimizer
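
A minimal usage sketch for the method above, assuming TensorFlow 1.x, a scalar `loss` tensor already built by the network, and a `trainer` object exposing `get_optimzer_lr`; `trainer`, `loss`, and `step_factor=1.0` are placeholders, not names or values from the project:

    # Sketch only: wires the returned schedule and optimizer into a train op.
    global_step = tf.Variable(0, trainable=False, name='global_step')
    lr, optimizer = trainer.get_optimzer_lr(global_step, step_factor=1.0)
    train_op = optimizer.minimize(loss, global_step=global_step)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        _, current_lr = sess.run([train_op, lr])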
Example No. 2
    def snapshot(self, sess, iter):
        epoch_iters = get_epoch_iters(self.net.model_name)
        # Zero-shot variants (default zs, rare, non-rare, unseen) keep the base
        # snapshot interval; all other models snapshot five times less often.
        if any(tag in self.net.model_name
               for tag in ('zs_', 'zsrare_', 'zsnrare')):
            snapshot_iters = cfg.TRAIN.SNAPSHOT_ITERS
        else:
            snapshot_iters = cfg.TRAIN.SNAPSHOT_ITERS * 5

        if (iter + 1) % snapshot_iters == 0 and iter != 0:
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)

            # Store the model snapshot
            filename = 'HOI' + '_iter_{:d}'.format(iter + 1) + '.ckpt'
            filename = os.path.join(self.output_dir, filename)
            self.saver.save(sess, filename)
            print('Wrote snapshot to: {:s}'.format(filename),
                  (iter + 1) // snapshot_iters)  # snapshot index
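
The checkpoints written above can be restored with the standard TF 1.x Saver. A minimal restore sketch, assuming the same network graph has already been rebuilt and `output_dir` points at the snapshot directory used by `snapshot()`:

    # Sketch only: restores the latest 'HOI_iter_*.ckpt' written by snapshot().
    ckpt = tf.train.get_checkpoint_state(output_dir)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        saver.restore(sess, ckpt.model_checkpoint_path)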
Example No. 3
            args.Restore_flag = -1
        elif 'cosine' in args.model and 's0' not in args.model:
            # This is for fine-tuning
            args.Restore_flag = -7
        elif 'unique_weights' in args.model:
            args.Restore_flag = 6
    if 'unique_weights' in args.model:
        args.Restore_flag = 6

    if args.Restore_flag == -1:
        ckpt = tf.train.get_checkpoint_state(output_dir)
        print(output_dir, ckpt.model_checkpoint_path)
        init_step = ckpt.model_checkpoint_path.split('/')[-1].split('_')[-1]
        init_step = int(init_step.replace('.ckpt', ''))
        start_epoch = init_step // get_epoch_iters(args.model)
    augment_type = get_augment_type(args.model)

    if 'res101' in args.model:
        os.environ['DATASET'] = 'HICO_res101'
    from networks.HOI import HOI
    net = HOI(model_name=args.model)

    pattern_type = 0
    zero_shot_type = get_zero_shot_type(args.model)
    large_neg_for_ho = False
    assert 'batch' in args.model
    logger.info("large neg: {}".format(large_neg_for_ho))
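
The `init_step` parsing in the `Restore_flag == -1` branch relies on the 'HOI_iter_{N}.ckpt' naming convention from Example No. 2. The same logic as a small stand-alone helper sketch (the function name is illustrative, not from the project):

    import os

    def parse_init_step(ckpt_path):
        # 'HOI_iter_80000.ckpt' -> 80000
        name = os.path.basename(ckpt_path)
        return int(name.split('_')[-1].replace('.ckpt', ''))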
Example No. 4
    def train_model_stepwise_inner(self, D_loss, g_loss, iter, lr, max_iters,
                                   sess, timer, train_op, train_op_g):
        while iter < max_iters + 1:
            timer.tic()

            total_loss = 0
            fake_total_loss = 0
            # Fabricator-only warm-up length in epochs before the discriminator
            # branch takes over; zero unless fine-tuning ('_s1_').
            epoch_stride = 0
            if '_s1_' in self.net.model_name:
                # This is for fine-tuning the fabricator in step-wise optimization
                epoch_stride = 1
            save_iters = get_epoch_iters(self.net.model_name)

            if iter < save_iters * epoch_stride:
                if (iter % cfg.TRAIN.SUMMARY_INTERVAL == 0) or (iter < 20):
                    # Compute the graph with summary
                    fake_total_loss, _, summary, image_id = sess.run([
                        g_loss,
                        train_op_g,
                        self.net.summary_op,
                        self.net.image_id,
                    ])

                    # total_loss, summary = self.net.train_step_with_summary(sess, blobs, lr.eval(), train_op)
                    self.writer.add_summary(summary, float(iter))
                else:
                    # Compute the graph without summary
                    fake_total_loss, _, image_id = sess.run([
                        g_loss,
                        train_op_g,
                        self.net.image_id,
                    ])
            else:
                if (iter % cfg.TRAIN.SUMMARY_INTERVAL == 0) or (iter < 20):

                    # Compute the graph with summary
                    total_loss, _, summary, image_id = sess.run([
                        D_loss,
                        train_op,
                        self.net.summary_op,
                        self.net.image_id,
                    ])

                    # total_loss, summary = self.net.train_step_with_summary(sess, blobs, lr.eval(), train_op)
                    self.writer.add_summary(summary, float(iter))

                else:
                    # Compute the graph without summary
                    total_loss, _, image_id = sess.run([
                        D_loss,
                        train_op,
                        self.net.image_id,
                    ])

            timer.toc()
            # print(image_id)
            # Display training information
            if iter % cfg.TRAIN.DISPLAY == 0:
                if isinstance(image_id, tuple):
                    image_id = image_id[0]
                print(
                    'iter: {:d} / {:d}, im_id: {:d}, loss: {:.6f}, G: {:.6f} lr: {:f}, speed: {:.3f} s/iter'
                    .format(iter, max_iters, image_id, total_loss,
                            fake_total_loss, lr.eval(), timer.average_time),
                    end='\n',
                    flush=True)
            # Snapshotting
            t_iter = iter
            self.snapshot(sess, t_iter)

            iter += 1
Example No. 5
    def train_model_stepwise_inner(self, D_loss, g_loss, iter, lr, max_iters, sess, timer, train_op, train_op_g):
        while iter < max_iters + 1:
            timer.tic()

            total_loss = 0
            fake_total_loss = 0
            # Fabricator-only warm-up length in epochs, selected by the '_sN_'
            # tag in the model name (default: 5 epochs).
            epoch_stride = 5
            if '_s3_' in self.net.model_name:
                epoch_stride = 3
            elif '_s1_' in self.net.model_name:
                epoch_stride = 1
            elif '_s05_' in self.net.model_name:
                epoch_stride = 0.5
            elif '_s0_' in self.net.model_name:
                epoch_stride = 0
            save_iters = get_epoch_iters(self.net.model_name)
            if iter < save_iters * epoch_stride and '_reload' not in self.net.model_name:
                if (iter % cfg.TRAIN.SUMMARY_INTERVAL == 0) or (iter < 20):
                    # Compute the graph with summary
                    fake_total_loss, _, summary, image_id = sess.run(
                        [g_loss, train_op_g, self.net.summary_op, self.net.image_id, ])

                    # total_loss, summary = self.net.train_step_with_summary(sess, blobs, lr.eval(), train_op)
                    self.writer.add_summary(summary, float(iter))
                else:
                    # Compute the graph without summary
                    fake_total_loss, _, image_id = sess.run([g_loss, train_op_g, self.net.image_id, ])
                if iter + 1 == save_iters * epoch_stride:
                    iter = save_iters - 1
            else:
                if (iter % cfg.TRAIN.SUMMARY_INTERVAL == 0) or (iter < 20):

                    # Compute the graph with summary
                    total_loss, _, summary, image_id = sess.run(
                        [D_loss, train_op, self.net.summary_op, self.net.image_id, ])

                    # total_loss, summary = self.net.train_step_with_summary(sess, blobs, lr.eval(), train_op)
                    self.writer.add_summary(summary, float(iter))

                else:
                    # Compute the graph without summary
                    total_loss, _, image_id = sess.run([D_loss, train_op, self.net.image_id, ])

            timer.toc()
            # print(image_id)
            # Display training information
            if iter % cfg.TRAIN.DISPLAY == 0:
                # if type(image_id) == tuple:
                #     image_id = image_id[0]
                # print(image_id)
                print('iter: {:d} / {:d}, im_id: {:d}, loss: {:.6f}, G: {:.6f} lr: {:f}, speed: {:.3f} s/iter'.format(
                    iter, max_iters, image_id[0], total_loss, fake_total_loss, lr.eval(), timer.average_time), end='\n',
                    flush=True)
            # print('\rmodel: {} im_detect: {:d}/{:d}  {:d}, {:.3f}s'.format(net.model_name, count, 15765, _image_id,
            #                                                                _t['im_detect'].average_time), end='',
            #       flush=True)
            # Snapshotting
            t_iter = iter
            # if iter == 0 and self.net.model_name.__contains__('_pret'):
            #     t_iter = 1000000
            self.snapshot(sess, t_iter)

            iter += 1
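
Both training loops above depend on a timer object with `tic()`, `toc()`, and `average_time`; the actual utility presumably comes from the project itself. A minimal stand-in sketch with the same interface, for reference only:

    import time

    class Timer(object):
        """Minimal stand-in matching the tic()/toc()/average_time usage above."""

        def __init__(self):
            self.total_time = 0.0
            self.calls = 0
            self._start = None

        def tic(self):
            # Mark the start of a timed iteration.
            self._start = time.time()

        def toc(self):
            # Accumulate the elapsed time for the current iteration.
            self.total_time += time.time() - self._start
            self.calls += 1

        @property
        def average_time(self):
            # Mean seconds per iteration so far.
            return self.total_time / max(self.calls, 1)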