Exemplo n.º 1
0
def train(config):
    """Train a ``cls.Cls`` model with validation-based early stopping.

    The loss mode is read from the first command-line argument
    (``sys.argv[1]``). Every ``config.valid_frequency_stud`` steps the model
    is validated. Whenever validation accuracy improves on the best seen so
    far, the model is also evaluated on ``model.test_dataset``. Training stops
    after ``config.max_training_step`` steps, or once
    ``config.max_endurance_stud`` validation checks in a row have passed
    without improvement.

    Args:
        config: Configuration object providing ``max_training_step``,
            ``max_endurance_stud``, ``valid_frequency_stud``,
            ``lambda1_stud`` and ``lambda2_stud``.

    Returns:
        The test accuracy measured when validation accuracy was at its best,
        or 0 if validation accuracy never rose above 0 (in which case the
        test set is never evaluated).
    """
    model = cls.Cls(config, loss_mode=sys.argv[1])
    model.initialize_weights()

    max_training_step = config.max_training_step
    best_acc = 0
    # Pre-initialised so the summary below cannot raise UnboundLocalError
    # when no validation check improves on best_acc (or no step runs at all).
    test_acc = 0
    # Number of validation checks since the last improvement.
    endurance = 0
    i = 0
    logger.info('lambda1: {}'.format(config.lambda1_stud))
    while i < max_training_step and endurance < config.max_endurance_stud:
        train_loss, train_acc = model.train()
        if i % config.valid_frequency_stud == 0:
            endurance += 1
            valid_loss, valid_acc, _, _ = model.valid()
            logger.info('====Step: {}===='.format(i))
            logger.info('train_loss: {}, valid_loss: {}'.format(
                train_loss, valid_loss))
            logger.info('train_acc: {}, valid_acc: {}'.format(
                train_acc, valid_acc))
            if valid_acc > best_acc:
                best_acc = valid_acc
                # Evaluate on the test split only when validation improves,
                # so test_acc always corresponds to the best validation point.
                _, test_acc, _, _ = model.valid(model.test_dataset)
                endurance = 0
        i += 1

    logger.info('lambda1: {}, lambda2: {}'.format(config.lambda1_stud,
                                                  config.lambda2_stud))
    logger.info('valid_acc: {}'.format(best_acc))
    logger.info('test_acc: {}'.format(test_acc))
    return test_acc
Exemplo n.º 2
0
def train(config):
    """Train a TF1 ``cls.Cls`` model in its own graph with early stopping.

    A fresh ``tf.Graph`` and a ``tf.InteractiveSession`` with GPU memory
    growth enabled are created for this run. The session is always closed
    before returning, including when training raises. The loss mode is read
    from ``sys.argv[1]``. Every ``config.valid_frequence_stud`` steps the
    model is validated. Whenever validation accuracy improves, the model is
    also evaluated on ``model.test_dataset``. Training stops after
    ``config.max_training_step`` steps, or once
    ``config.max_endurance_stud`` validation checks in a row have passed
    without improvement.

    Args:
        config: Configuration object providing ``max_training_step``,
            ``max_endurance_stud``, ``valid_frequence_stud``,
            ``lambda1_stud`` and ``lambda2_stud``.

    Returns:
        The test accuracy measured when validation accuracy was at its best,
        or 0 if validation accuracy never rose above 0 (in which case the
        test set is never evaluated).
    """
    g = tf.Graph()
    gpu_options = tf.GPUOptions(allow_growth=True)
    config_proto = tf.ConfigProto(gpu_options=gpu_options)
    # An InteractiveSession installs itself as the default session on
    # construction, so it must be closed explicitly. Otherwise each call
    # leaks a session and its resources.
    sess = tf.InteractiveSession(config=config_proto, graph=g)
    try:
        model = cls.Cls(config, g, loss_mode=sys.argv[1])
        sess.run(model.init)

        max_training_step = config.max_training_step
        best_acc = 0
        # Pre-initialised so the summary below cannot raise
        # UnboundLocalError when validation never improves on best_acc.
        test_acc = 0
        # Number of validation checks since the last improvement.
        endurance = 0
        i = 0
        while (i < max_training_step
               and endurance < config.max_endurance_stud):
            train_loss, train_acc = model.train(sess)
            # NOTE(review): attribute is spelled "valid_frequence_stud" here;
            # confirm it matches the config definition.
            if i % config.valid_frequence_stud == 0:
                endurance += 1
                valid_loss, valid_acc, _, _ = model.valid(sess)
                if valid_acc > best_acc:
                    best_acc = valid_acc
                    # Evaluate on the test split only when validation improves.
                    _, test_acc, _, _ = model.valid(sess, model.test_dataset)
                    endurance = 0
            i += 1
    finally:
        sess.close()

    logger.info('lambda1: {}, lambda2: {}'.format(config.lambda1_stud,
                                                  config.lambda2_stud))
    logger.info('valid_acc: {}'.format(best_acc))
    logger.info('test_acc: {}'.format(test_acc))
    return test_acc
Exemplo n.º 3
0
    def __init__(self, config, args):
        """Build the RL controller and the task model selected by *args*.

        Args:
            config: Configuration object. ``config.rl_method`` selects the
                controller implementation ('reinforce' or 'ppo').
            args: Parsed arguments providing ``exp_name`` (used as a prefix
                for the models' names) and ``task_name`` (one of 'reg',
                'cls', 'gan', 'gan_cifar10', 'nmt').

        Raises:
            NotImplementedError: If ``config.rl_method`` or
                ``args.task_name`` is not one of the supported values.
        """
        self.config = config

        exp_name = args.exp_name
        logger.info('exp_name: {}'.format(exp_name))

        if config.rl_method == 'reinforce':
            self.model_ctrl = controller_reinforce.Controller(
                config, exp_name + '_ctrl')
        elif config.rl_method == 'ppo':
            self.model_ctrl = controller_ppo.MlpPPO(config, exp_name + '_ctrl')
        else:
            # Fail fast. Without this, an unknown method would leave
            # self.model_ctrl unset and surface later as an AttributeError.
            raise NotImplementedError(
                'Unsupported rl_method: {!r}'.format(config.rl_method))

        # Task modules are imported inside each branch, presumably so that
        # only the selected task's dependencies need to be importable.
        if args.task_name == 'reg':
            from models import reg
            self.model_task = reg.Reg(config, exp_name + '_reg')
        elif args.task_name == 'cls':
            from models import cls
            self.model_task = cls.Cls(config, exp_name + '_cls')
        elif args.task_name == 'gan':
            from models import gan
            self.model_task = gan.Gan(config, exp_name + '_gan')
        elif args.task_name == 'gan_cifar10':
            from models import gan_cifar10
            self.model_task = gan_cifar10.Gan_cifar10(
                config, exp_name + '_gan_cifar10')
        elif args.task_name == 'nmt':
            from models import nmt
            self.model_task = nmt.Nmt(config, exp_name + '_nmt')
        else:
            raise NotImplementedError(
                'Unsupported task_name: {!r}'.format(args.task_name))
Exemplo n.º 4
0
    def __init__(self, config, exp_name=None, arch=None):
        """Build the controller and the student model named by the config.

        Args:
            config: Configuration object. ``config.student_model_name``
                selects the student ('toy', 'cls', 'gan', 'gan_grid' or
                'gan_cifar10').
            exp_name: Prefix for the models' names. When falsy, a default
                ``<host>_<MM-DD-HH-MM>`` is built. ``<host>`` is the first two
                dot-separated components of the hostname joined by '-', and
                the timestamp is the current UTC time.
            arch: Forwarded only to the 'gan' student model.

        Raises:
            NotImplementedError: If ``config.student_model_name`` is not
                supported.
        """
        self.config = config

        if not exp_name:
            # Query the host and clock only when a default name is needed.
            hostname = '-'.join(socket.gethostname().split('.')[0:2])
            timestamp = strftime('%m-%d-%H-%M', gmtime())
            exp_name = '{}_{}'.format(hostname, timestamp)
        logger.info('exp_name: {}'.format(exp_name))

        self.model_ctrl = controller.Controller(config, exp_name + '_ctrl')
        student_name = config.student_model_name
        if student_name == 'toy':
            # NOTE(review): the toy student uses the '_reg' suffix rather than
            # '_toy'; kept as-is, confirm this is intended.
            self.model_stud = toy.Toy(config, exp_name + '_reg')
        elif student_name == 'cls':
            self.model_stud = cls.Cls(config, exp_name + '_cls')
        elif student_name == 'gan':
            self.model_stud = gan.Gan(config, exp_name + '_gan', arch=arch)
        elif student_name == 'gan_grid':
            self.model_stud = gan_grid.Gan_grid(config, exp_name + '_gan_grid')
        elif student_name == 'gan_cifar10':
            self.model_stud = gan_cifar10.Gan_cifar10(
                config, exp_name + '_gan_cifar10')
        else:
            raise NotImplementedError(
                'Unsupported student_model_name: {!r}'.format(student_name))