Example #1
def setup_training_tools_0001(this_net, config_data, verbose=0, tab_level=0):
    pm.printvm("setup_training_tools_0001().",
               tab_level=tab_level,
               verbose=verbose,
               verbose_threshold=50)

    criterion = nn.CrossEntropyLoss()
    from utils.optimizer import get_optimizer
    optimizer = get_optimizer(this_net,
                              config_data,
                              verbose=verbose,
                              tab_level=tab_level + 1)
    return criterion, optimizer
def train():
    if args.dataset == 'COCO':
        print("WARNING: Using default COCO dataset_root because " +
              "--dataset_root was not specified.")
        exit()
    elif args.dataset == 'VOC':
        if args.net_type == 'tdsod':
            print('tdsod setting')
            cfg = TDSOD_voc
        else:
            cfg = voc
        dataset = VOCDetection(root=VOC_ROOT,
                               transform=SSDAugmentation(
                                   cfg['min_dim'], MEANS))

    if args.net_type == 'tdsod':
        print('Build TDSOD')
        net, head = build_tdsod(phase='train')
    elif args.net_type == 'qssd':
        print('Build SSD')
        net, head = build_ssd(phase='train')
    else:
        print('Only the tdsod and qssd net types are supported.')
        exit()

    if args.cuda:
        if num_gpu > 1:
            net = torch.nn.DataParallel(net)
            head = torch.nn.DataParallel(head)
        cudnn.benchmark = True

    print(net)
    print(head)
    print('the number of model parameters: {}'.format(
        sum([p.data.nelement() for p in net.parameters()])))
    print('training step: ', cfg['lr_steps'])
    print('max step: ', cfg['max_iter'])

    if args.resume:
        print('Resuming training, loading {}...'.format(args.resume))
        net.load_weights(args.resume)

    if args.cuda:
        net = net.cuda()
        head = head.cuda()

    if args.scratch:
        if args.net_type == 'tdsod':
            print('Initializing weights for training from SCRATCH - TDSOD')
            net.apply(weights_init)
            head.apply(weights_init)
        elif args.net_type == 'qssd':
            print('Initializing weights for training from SCRATCH - QSSD')
            net.extras.apply(weights_init)
            head.apply(weights_init)
    else:
        print('This code only supports training from scratch.')
        exit()

    optimizer = get_optimizer(args.optimizer,
                              list(net.parameters()) + list(head.parameters()),
                              args=args)
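    # MultiBoxLoss positional arguments below follow the usual ssd.pytorch
    # signature (an assumption about this repo): overlap_thresh=0.5,
    # prior_for_matching=True, bkg_label=0, neg_mining=True, neg_pos=3,
    # neg_overlap=0.5, encode_target=False, use_gpu=args.cuda.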
    criterion = MultiBoxLoss(cfg['num_classes'], 0.5, True, 0, True, 3, 0.5,
                             False, args.cuda)

    net.train()
    head.train()

    # loss counters
    loc_loss = 0
    conf_loss = 0
    epoch = 0

    if not os.path.exists('weights'):
        os.makedirs('weights')
    print('Loading the dataset...')

    epoch_size = len(dataset) // args.batch_size  # actually iterations per epoch, not epoch size
    print('Dataset size, Total epoch:', len(dataset),
          (cfg['max_iter'] - args.start_iter) / epoch_size)
    print('Training SSD on:', dataset.name)
    print('Using the specified args:')
    print(args)

    step_index = 0

    data_loader = data.DataLoader(dataset,
                                  args.batch_size,
                                  num_workers=args.num_workers,
                                  shuffle=True,
                                  collate_fn=detection_collate,
                                  pin_memory=True)

    # assign fp warmup
    # create batch iterator
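    # When quantization is requested, run a full-precision warm-up for two
    # epochs' worth of iterations (presumably so optimizer state and BN
    # statistics settle) before fake-quantization modules are inserted.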
    if args.quant:
        if args.warmup:
            print('start optimizer warm-up')
            batch_iterator = None
            start_time = time.time()
            for iteration in range(args.start_iter, 2 * epoch_size):
                #     print(not batch_iterator, iteration % epoch_size)
                t0 = time.time()
                if (not batch_iterator) or (iteration % epoch_size == 0):
                    batch_iterator = iter(data_loader)
                if args.visdom and iteration != 0 and (iteration % epoch_size
                                                       == 0):
                    # reset epoch loss counters
                    loc_loss = 0
                    conf_loss = 0

                if iteration in cfg['lr_steps']:
                    step_index += 1
                    adjust_learning_rate(optimizer, args.gamma, step_index)

                # load train data
                images, targets = next(batch_iterator)

                if args.cuda:
                    images = Variable(images.cuda())
                    with torch.no_grad():
                        targets = [Variable(ann.cuda()) for ann in targets]
                else:
                    images = Variable(images)
                    with torch.no_grad():
                        targets = [Variable(ann) for ann in targets]
                # forward
                t1 = time.time()
                out = head(net(images))

                # backprop
                optimizer.zero_grad()
                loss_l, loss_c = criterion(out, targets)
                loss = loss_l + loss_c
                loss.backward()
                optimizer.step()
                t2 = time.time()
                loc_loss += loss_l.item()
                conf_loss += loss_c.item()

                if iteration % 100 == 0:
                    current_LR = get_learning_rate(optimizer)[0]
                    print(
                        'iter ' + repr(iteration) + '|| LR: ' +
                        repr(current_LR) + ' || Loss: %.4f ||' % (loss.item()),
                        'batch/forw. time: %.2f' % (t2 - t0),
                        '%.2f:' % (t2 - t1), 'avg. time: %.4f' %
                        ((time.time() - start_time) / (iteration + 1)))
            start_iter = iteration + 1  # resume the main loop after the FP warm-up
        else:
            start_iter = 0

        # quantization on
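        # QAT preparation: fuse_model() folds Conv+BN(+ReLU) into single
        # modules, the qnnpack qconfig selects the default QAT observers,
        # and prepare_qat() inserts fake-quantization modules in place.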
        if num_gpu > 1:
            net.module.fuse_model()
            net.module.qconfig = torch.quantization.get_default_qat_qconfig(
                'qnnpack')
            torch.quantization.prepare_qat(net.module, inplace=True)
            print('quant on')
        else:
            net.fuse_model()
            net.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
            torch.quantization.prepare_qat(net, inplace=True)
            print('quant on')
    else:
        print('run FP only model')
        start_iter = 0

    # create batch iterator
    batch_iterator = None
    start_time = time.time()
    for iteration in range(start_iter, cfg['max_iter']):

        t0 = time.time()
        if (not batch_iterator) or (iteration % epoch_size == 0):
            batch_iterator = iter(data_loader)
        if args.visdom and iteration != 0 and (iteration % epoch_size == 0):
            # reset epoch loss counters
            loc_loss = 0
            conf_loss = 0
            epoch += 1

        if iteration in cfg['lr_steps']:
            step_index += 1
            adjust_learning_rate(optimizer, args.gamma, step_index)

        # load train data
        images, targets = next(batch_iterator)

        if args.cuda:
            images = Variable(images.cuda())
            with torch.no_grad():
                targets = [Variable(ann.cuda()) for ann in targets]
        else:
            images = Variable(images)
            with torch.no_grad():
                targets = [Variable(ann) for ann in targets]
        # forward
        t1 = time.time()
        out = head(net(images))

        # backprop
        optimizer.zero_grad()
        loss_l, loss_c = criterion(out, targets)
        loss = loss_l + loss_c
        loss.backward()
        optimizer.step()
        t2 = time.time()
        loc_loss += loss_l.item()
        conf_loss += loss_c.item()

        if iteration % 100 == 0:
            current_LR = get_learning_rate(optimizer)[0]

            print(
                'iter ' + repr(iteration) + '|| LR: ' + repr(current_LR) +
                ' || Loss: %.4f ||' % (loss.item()),
                'batch/forw. time: %.2f' % (t2 - t0), '%.2f:' % (t2 - t1),
                'avg. time: %.4f' % ((time.time() - start_time) /
                                     (iteration + 1)))

        if iteration != 0 and iteration % args.save_iter == 0:
            print('Saving state and evaluation at iter:', iteration)
            if num_gpu > 1:
                torch.save(net.module.state_dict(),
                           'weights/ssd300_f_' + repr(iteration) + '.pth')
                torch.save(head.module.state_dict(),
                           'weights/ssd300_h_' + repr(iteration) + '.pth')
            else:
                torch.save(net.state_dict(),
                           'weights/ssd300_f_' + repr(iteration) + '.pth')
                torch.save(head.state_dict(),
                           'weights/ssd300_h_' + repr(iteration) + '.pth')
            mean_AP = evaluator(args.net_type,
                                'VOC0712',
                                'weights/ssd300_f_' + repr(iteration) + '.pth',
                                'weights/ssd300_h_' + repr(iteration) + '.pth',
                                cuda=args.cuda,
                                quant=args.quant,
                                verbose=False)

    if num_gpu > 1:
        torch.save(net.module.state_dict(), 'weights/ssd300_f_final.pth')
        torch.save(head.module.state_dict(), 'weights/ssd300_h_final.pth')
    else:
        torch.save(net.state_dict(), 'weights/ssd300_f_final.pth')
        torch.save(head.state_dict(), 'weights/ssd300_h_final.pth')
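A minimal, self-contained sketch of the quantization-aware training (QAT) preparation used in train() above; the toy model and layer names are illustrative assumptions, not taken from the source.

import torch
import torch.nn as nn
import torch.quantization as tq

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

    def fuse_model(self):
        # Fold Conv+BN+ReLU so the sequence is quantized as one unit
        # (newer PyTorch versions may require fuse_modules_qat here).
        tq.fuse_modules(self, [['conv', 'bn', 'relu']], inplace=True)

net = TinyNet()
net.train()
net.fuse_model()
net.qconfig = tq.get_default_qat_qconfig('qnnpack')
tq.prepare_qat(net, inplace=True)   # insert fake-quantization observers
# ... run the usual training loop here; observers learn activation ranges ...
net.eval()
quantized_net = tq.convert(net)     # produce the actual int8 model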
def training(args):
    
    # DIRECTORY FOR CKPTS and META FILES
    # ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    ROOT_DIR = '/media/data/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'train.tfrecords')
    VAL_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'val.tfrecords')
    CKPT_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        args.ckpt_folder_name + '/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer(
            [TRAIN_REC_PATH], num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size)

        val_queue = tf.train.string_input_producer(
            [VAL_REC_PATH], num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size)

        with tf.variable_scope('bipn'):
            print('TRAIN FRAMES (first):')
            train_rec_iFrames = bipn.build_bipn(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True)

        with tf.variable_scope('bipn', reuse=tf.AUTO_REUSE):
            print('VAL FRAMES (first):')
            val_rec_iFrames = bipn.build_bipn(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                is_training=False)
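        # Reusing the 'bipn' variable scope (reuse=tf.AUTO_REUSE) makes the
        # validation graph share weights with the training graph built above.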
            
        print('Model parameters:{}'.format(
            count_parameters()))

        if args.perceptual_loss_weight:
            # VGG16 weights should be kept locally (~500 MB of space)
            with tf.variable_scope('vgg16'):
                train_iFrames_features = vgg16.build_vgg16(
                    train_iFrames, end_point='conv4_3').features
            with tf.variable_scope('vgg16', reuse=tf.AUTO_REUSE):
                train_rec_iFrames_features = vgg16.build_vgg16(
                    train_rec_iFrames, end_point='conv4_3').features

        # DEFINE METRICS
        if args.loss_id == 0:
            train_loss = huber_loss(
                train_iFrames, train_rec_iFrames,
                delta=1.)
            val_loss = huber_loss(
                val_iFrames, val_rec_iFrames,
                delta=1.)

        elif args.loss_id == 1:
            train_loss = l2_loss(
                train_iFrames, train_rec_iFrames)
            val_loss = l2_loss(
                val_iFrames, val_rec_iFrames)

        total_train_loss = train_loss
        tf.summary.scalar('train_loss', train_loss)
        tf.summary.scalar('val_loss', val_loss)

        if args.perceptual_loss_weight:
            train_perceptual_loss = perceptual_loss(
                train_iFrames_features,
                train_rec_iFrames_features)

            tf.summary.scalar('train_perceptual_loss',\
                train_perceptual_loss)

            total_train_loss += train_perceptual_loss\
                * args.perceptual_loss_weight

        # SUMMARIES
        tf.summary.scalar('total_train_loss',\
            total_train_loss)
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            CKPT_PATH + 'train',
            sess.graph)

        # DEFINE OPTIMIZER (minimize the total objective, which includes the
        # perceptual term when it is enabled)
        optimizer = get_optimizer(
            total_train_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate,
            use_batch_norm=True)

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(
            coord=coord)
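        # start_queue_runners launches the background threads that feed the
        # string_input_producer queues; the Coordinator coordinates shutdown.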

        # START TRAINING HERE
        try:
            for iteration in range(args.train_iters):
                _, t_summ, t_loss = sess.run(
                    [optimizer, merged, total_train_loss])

                train_writer.add_summary(t_summ, iteration)
                print('Iter:{}/{}, Train Loss:{}'.format(
                    iteration,
                    args.train_iters,
                    t_loss))

                if iteration % args.val_every == 0:
                    v_loss = sess.run(val_loss)
                    print('Iter:{}, Val Loss:{}'.format(
                        iteration,
                        v_loss))

                if iteration % args.save_every == 0:
                    saver.save(
                        sess,
                        CKPT_PATH + 'iter:{}_val:{}'.format(
                            str(iteration),
                            str(round(v_loss, 3))))

                if iteration % args.plot_every == 0:
                    start_frames, end_frames, mid_frames,\
                        rec_mid_frames = sess.run(
                            [train_fFrames, train_lFrames,\
                                train_iFrames,\
                                train_rec_iFrames])

                    visualize_frames(
                        start_frames,
                        end_frames,
                        mid_frames,
                        rec_mid_frames,
                        iteration=iteration,
                        save_path=os.path.join(
                            CKPT_PATH,
                            'plots/'))

        except Exception as e:
            coord.request_stop(e)
def training(args):
    
    # DIRECTORY FOR CKPTS and META FILES
    ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'train.tfrecords')
    VAL_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'val.tfrecords')
    CKPT_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'skip_merge_conv_summary_image_tanh_separate_bipn_wd_l2_adam_1e-3/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer(
            [TRAIN_REC_PATH], num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size)

        val_queue = tf.train.string_input_producer(
            [VAL_REC_PATH], num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size)

        with tf.variable_scope('separate_bipn'):
            print('TRAIN FRAMES (first):')
            train_rec_iFrames = skip_merge_conv_separate_encoder_bipn.build_bipn(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True)

        with tf.variable_scope('separate_bipn', reuse=tf.AUTO_REUSE):
            print('VAL FRAMES (first):')
            val_rec_iFrames = skip_merge_conv_separate_encoder_bipn.build_bipn(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                is_training=False)
            
        print('Model parameters:{}'.format(
            count_parameters()))

        # DEFINE METRICS
        if args.loss_id == 0:
            train_loss = huber_loss(
                train_iFrames, train_rec_iFrames,
                delta=1.)
            val_loss = huber_loss(
                val_iFrames, val_rec_iFrames,
                delta=1.)

        elif args.loss_id == 1:
            train_loss = l2_loss(
                train_iFrames, train_rec_iFrames)
            val_loss = l2_loss(
                val_iFrames, val_rec_iFrames) 

        if args.weight_decay:
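            # ridge_weight_decay presumably returns the summed squared L2 norm
            # ("ridge" = L2 penalty) of the given trainable variables.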
            decay_loss = ridge_weight_decay(
                tf.trainable_variables())
            train_loss += args.weight_decay * decay_loss

        # SUMMARIES
        tf.summary.scalar('train_loss', train_loss)
        tf.summary.scalar('val_loss', val_loss)

        with tf.contrib.summary.\
            record_summaries_every_n_global_steps(
                n=args.summary_image_every):
            summary_true, summary_fake = visualize_tensorboard(
                train_fFrames,
                train_lFrames,
                train_iFrames,
                train_rec_iFrames,
                num_plots=3)
            tf.summary.image('true frames', summary_true)
            tf.summary.image('fake frames', summary_fake)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            CKPT_PATH + 'train',
            sess.graph)

        # DEFINE OPTIMIZER
        optimizer = get_optimizer(
            train_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate,
            use_batch_norm=True)

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(
            coord=coord)

        # START TRAINING HERE
        try:
            for iteration in range(args.train_iters):
                _, t_summ, t_loss = sess.run(
                    [optimizer, merged, train_loss])

                train_writer.add_summary(t_summ, iteration)
                print('Iter:{}/{}, Train Loss:{}'.format(
                    iteration,
                    args.train_iters,
                    t_loss))

                if iteration % args.val_every == 0:
                    v_loss = sess.run(val_loss)
                    print('Iter:{}, Val Loss:{}'.format(
                        iteration,
                        v_loss))

                if iteration % args.save_every == 0:
                    saver.save(
                        sess,
                        CKPT_PATH + 'iter:{}_val:{}'.format(
                            str(iteration),
                            str(round(v_loss, 3))))

                '''
                if iteration % args.plot_every == 0:
                    start_frames, end_frames, mid_frames,\
                        rec_mid_frames = sess.run(
                            [train_fFrames, train_lFrames,\
                                train_iFrames,\
                                train_rec_iFrames])

                    visualize_frames(
                        start_frames,
                        end_frames,
                        mid_frames,
                        rec_mid_frames,
                        iteration=iteration,
                        save_path=os.path.join(
                            CKPT_PATH,
                            'plots/'))
                '''

            coord.join(threads)

        except Exception as e:
            coord.request_stop(e)
def training(args):

    # DIRECTORY FOR CKPTS and META FILES
    # ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    ROOT_DIR = '/media/data/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(ROOT_DIR, args.experiment_name,
                                  'train.tfrecords')
    VAL_REC_PATH = os.path.join(ROOT_DIR, args.experiment_name,
                                'val.tfrecords')
    CKPT_PATH = os.path.join(ROOT_DIR, args.experiment_name,
                             args.ckpt_folder_name + '/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer([TRAIN_REC_PATH],
                                                     num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size,
                n_intermediate_frames=args.n_IF)

        val_queue = tf.train.string_input_producer([VAL_REC_PATH],
                                                   num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size,
                n_intermediate_frames=args.n_IF)

        # Apply gaussian blurring manually
        '''
        train_fFrames = gaussian_filter(train_fFrames, std=std_dev)
        train_lFrames = gaussian_filter(train_lFrames, std=std_dev)
        train_iFrames = gaussian_filter(train_iFrames, std=std_dev)
        val_fFrames = gaussian_filter(val_fFrames, std=std_dev)
        val_lFrames = gaussian_filter(val_lFrames, std=std_dev)
        val_iFrames = gaussian_filter(val_iFrames, std=std_dev)
        '''

        # TRAINABLE
        print('---------------------------------------------')
        print('----------------- GENERATOR -----------------')
        print('---------------------------------------------')
        with tf.variable_scope('generator'):
            train_rec_iFrames = generator_unet.build_generator(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True,
                n_IF=args.n_IF,
                starting_out_channels=args.starting_out_channels,
                use_attention=args.use_attention,
                spatial_attention=args.spatial_attention,
                is_verbose=True)

        print('---------------------------------------------')
        print('-------------- DISCRIMINATOR ----------------')
        print('---------------------------------------------')
        # discriminator for classifying real images
        with tf.variable_scope('discriminator'):
            train_real_output_discriminator = discriminator.build_discriminator(
                train_iFrames,
                use_batch_norm=True,
                is_training=True,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=True)
        # discriminator for classifying fake images
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            train_fake_output_discriminator = discriminator.build_discriminator(
                train_rec_iFrames,
                use_batch_norm=True,
                is_training=True,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=False)

        # VALIDATION
        with tf.variable_scope('generator', reuse=tf.AUTO_REUSE):
            val_rec_iFrames = generator_unet.build_generator(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                n_IF=args.n_IF,
                is_training=False,
                starting_out_channels=args.starting_out_channels,
                use_attention=args.use_attention,
                spatial_attention=args.spatial_attention,
                is_verbose=False)

        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            val_real_output_discriminator = discriminator.build_discriminator(
                val_iFrames,
                use_batch_norm=True,
                is_training=False,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=False)
        with tf.variable_scope('discriminator', reuse=tf.AUTO_REUSE):
            val_fake_output_discriminator = discriminator.build_discriminator(
                val_rec_iFrames,
                use_batch_norm=True,
                is_training=False,
                starting_out_channels=args.discri_starting_out_channels,
                is_verbose=False)

        if args.perceptual_loss_weight:
            # Weights should be kept locally ~ 500 MB space
            with tf.variable_scope('vgg16'):
                train_iFrames_features = vgg16.build_vgg16(
                    train_iFrames,
                    end_point=args.perceptual_loss_endpoint).features
            with tf.variable_scope('vgg16', reuse=tf.AUTO_REUSE):
                train_rec_iFrames_features = vgg16.build_vgg16(
                    train_rec_iFrames,
                    end_point=args.perceptual_loss_endpoint).features

        print('Global parameters:{}'.format(
            count_parameters(tf.global_variables())))
        print('Learnable model parameters:{}'.format(
            count_parameters(tf.trainable_variables())))

        # DEFINE GAN losses:
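        # Least-squares GAN (LSGAN) objectives: the discriminator pushes real
        # outputs toward 1 and fake outputs toward 0, while the generator
        # pushes the discriminator's fake outputs toward 1; each term is
        # averaged over the batch.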
        train_discri_real_loss = tf.reduce_sum(
            tf.square(train_real_output_discriminator - 1)) / (2 *
                                                               args.batch_size)
        train_discri_fake_loss = tf.reduce_sum(
            tf.square(train_fake_output_discriminator)) / (2 * args.batch_size)
        train_discriminator_loss = train_discri_real_loss + train_discri_fake_loss

        train_generator_fake_loss = tf.reduce_sum(
            tf.square(train_fake_output_discriminator - 1)) / args.batch_size
        train_reconstruction_loss = l2_loss(
            train_rec_iFrames, train_iFrames) * args.reconstruction_loss_weight
        train_generator_loss = train_generator_fake_loss + train_reconstruction_loss

        val_discri_real_loss = tf.reduce_sum(
            tf.square(val_real_output_discriminator - 1)) / (2 *
                                                             args.batch_size)
        val_discri_fake_loss = tf.reduce_sum(
            tf.square(val_fake_output_discriminator)) / (2 * args.batch_size)
        val_discriminator_loss = val_discri_real_loss + val_discri_fake_loss

        val_generator_fake_loss = tf.reduce_sum(
            tf.square(val_fake_output_discriminator - 1)) / args.batch_size
        val_reconstruction_loss = l2_loss(
            val_rec_iFrames, val_iFrames) * args.reconstruction_loss_weight
        val_generator_loss = val_generator_fake_loss + val_reconstruction_loss

        if args.perceptual_loss_weight:
            train_percp_loss = perceptual_loss(train_rec_iFrames_features,
                                               train_iFrames_features)
            train_generator_loss += args.perceptual_loss_weight * train_percp_loss

        # SUMMARIES
        tf.summary.scalar('train_discri_real_loss', train_discri_real_loss)
        tf.summary.scalar('train_discri_fake_loss', train_discri_fake_loss)
        tf.summary.scalar('train_discriminator_loss', train_discriminator_loss)
        tf.summary.scalar('train_generator_fake_loss',
                          train_generator_fake_loss)
        tf.summary.scalar('train_reconstruction_loss',
                          train_reconstruction_loss)
        tf.summary.scalar('train_generator_loss', train_generator_loss)

        tf.summary.scalar('val_discri_real_loss', val_discri_real_loss)
        tf.summary.scalar('val_discri_fake_loss', val_discri_fake_loss)
        tf.summary.scalar('val_discriminator_loss', val_discriminator_loss)
        tf.summary.scalar('val_generator_fake_loss', val_generator_fake_loss)
        tf.summary.scalar('val_reconstruction_loss', val_reconstruction_loss)
        tf.summary.scalar('val_generator_loss', val_generator_loss)

        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(CKPT_PATH + 'train', sess.graph)

        # get variables responsible for generator and discriminator
        trainable_vars = tf.trainable_variables()
        generator_vars = [
            var for var in trainable_vars if 'generator' in var.name
        ]
        discriminator_vars = [
            var for var in trainable_vars if 'discriminator' in var.name
        ]

        # DEFINE OPTIMIZERS
        generator_optimizer = get_optimizer(train_generator_loss,
                                            optim_id=args.optim_id,
                                            learning_rate=args.learning_rate,
                                            use_batch_norm=True,
                                            var_list=generator_vars)
        discriminator_optimizer = get_optimizer(
            train_discriminator_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate * 2.,
            use_batch_norm=True,
            var_list=discriminator_vars)
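        # Each optimizer only updates its own variable list; the discriminator
        # is trained with twice the generator learning rate.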

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        # START TRAINING HERE
        for iteration in range(args.train_iters):

            for d_iteration in range(args.disc_train_iters):
                disc_, td_loss = sess.run(
                    [discriminator_optimizer, train_discriminator_loss])

            gene_, tgf_loss, tr_loss, t_summ = sess.run(
                [generator_optimizer, train_generator_fake_loss,\
                    train_reconstruction_loss, merged])

            train_writer.add_summary(t_summ, iteration)

            print(
                'Iter:{}/{}, Disc. Loss:{}, Gen. Loss:{}, Rec. Loss:{}'.format(
                    iteration, args.train_iters, str(round(td_loss, 6)),
                    str(round(tgf_loss, 6)), str(round(tr_loss, 6))))

            if iteration % args.val_every == 0:
                vd_loss, vgf_loss, vr_loss = sess.run(
                    [val_discriminator_loss, val_generator_fake_loss,\
                        val_reconstruction_loss])
                print('Iter:{}, VAL Disc. Loss:{}, Gen. Loss:{}, Rec. Loss:{}'.
                      format(iteration, str(round(vd_loss, 6)),
                             str(round(vgf_loss, 6)), str(round(vr_loss, 6))))

            if iteration % args.save_every == 0:
                saver.save(
                    sess, CKPT_PATH +
                    'iter:{}_valDisc:{}_valGen:{}_valRec:{}'.format(
                        str(iteration), str(round(vd_loss, 6)),
                        str(round(vgf_loss, 6)), str(round(vr_loss, 6))))

            if iteration % args.plot_every == 0:
                start_frames, end_frames, mid_frames,\
                    rec_mid_frames = sess.run(
                        [train_fFrames, train_lFrames,\
                            train_iFrames,\
                            train_rec_iFrames])

                visualize_frames(start_frames,
                                 end_frames,
                                 mid_frames,
                                 rec_mid_frames,
                                 training=True,
                                 iteration=iteration,
                                 save_path=os.path.join(
                                     CKPT_PATH, 'train_plots/'))

                start_frames, end_frames, mid_frames,\
                    rec_mid_frames = sess.run(
                        [val_fFrames, val_lFrames,\
                            val_iFrames,
                            val_rec_iFrames])

                visualize_frames(start_frames,
                                 end_frames,
                                 mid_frames,
                                 rec_mid_frames,
                                 training=False,
                                 iteration=iteration,
                                 save_path=os.path.join(
                                     CKPT_PATH, 'validation_plots/'))

        print('Training complete.....')
def training_UNet3D_setup_tools(this_net, config_data):
	criterion = nn.CrossEntropyLoss()
	optimizer = op.get_optimizer(this_net, config_data)
	cetracker = ev.CrossEntropyLossTracker(config_data, display_every_n_minibatchs=2)
	return criterion, optimizer, cetracker
Example #7
def training(args):
    
    # DIRECTORY FOR CKPTS and META FILES
    # ROOT_DIR = '/neuhaus/movie/dataset/tf_records'
    ROOT_DIR = '/media/data/movie/dataset/tf_records'
    TRAIN_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'train.tfrecords')
    VAL_REC_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        'val.tfrecords')
    CKPT_PATH = os.path.join(
        ROOT_DIR,
        args.experiment_name,
        args.ckpt_folder_name + '/')

    # SCOPING BEGINS HERE
    with tf.Session().as_default() as sess:
        global_step = tf.train.get_global_step()

        train_queue = tf.train.string_input_producer(
            [TRAIN_REC_PATH], num_epochs=None)
        train_fFrames, train_lFrames, train_iFrames, train_mfn =\
            read_and_decode(
                filename_queue=train_queue,
                is_training=True,
                batch_size=args.batch_size,
                n_intermediate_frames=args.n_IF)

        val_queue = tf.train.string_input_producer(
            [VAL_REC_PATH], num_epochs=None)
        val_fFrames, val_lFrames, val_iFrames, val_mfn = \
            read_and_decode(
                filename_queue=val_queue,
                is_training=False,
                batch_size=args.batch_size,
                n_intermediate_frames=args.n_IF)

        # Apply gaussian blurring manually
        '''
        train_fFrames = gaussian_filter(train_fFrames, std=std_dev)
        train_lFrames = gaussian_filter(train_lFrames, std=std_dev)
        train_iFrames = gaussian_filter(train_iFrames, std=std_dev)
        val_fFrames = gaussian_filter(val_fFrames, std=std_dev)
        val_lFrames = gaussian_filter(val_lFrames, std=std_dev)
        val_iFrames = gaussian_filter(val_iFrames, std=std_dev)
        '''

        with tf.variable_scope('separate_bipn'):
            print('TRAIN FRAMES (first):')
            train_rec_iFrames = skip_separate_encoder_bipn.build_bipn(
                train_fFrames,
                train_lFrames,
                use_batch_norm=True,
                is_training=True,
                n_IF=args.n_IF,
                starting_out_channels=args.starting_out_channels,
                use_attention=args.use_attention,
                spatial_attention=args.spatial_attention,
                is_verbose=True)

        with tf.variable_scope('separate_bipn', reuse=tf.AUTO_REUSE):
            print('VAL FRAMES (first):')
            val_rec_iFrames = skip_separate_encoder_bipn.build_bipn(
                val_fFrames,
                val_lFrames,
                use_batch_norm=True,
                n_IF=args.n_IF,
                is_training=False,
                starting_out_channels=args.starting_out_channels,
                use_attention=args.use_attention,
                spatial_attention=args.spatial_attention,
                is_verbose=False)
            
        if args.perceptual_loss_weight:
            # Weights should be kept locally ~ 500 MB space
            with tf.variable_scope('vgg16'):
                train_iFrames_features = vgg16.build_vgg16(
                    train_iFrames,
                    end_point=args.perceptual_loss_endpoint).features
            with tf.variable_scope('vgg16', reuse=tf.AUTO_REUSE):
                train_rec_iFrames_features = vgg16.build_vgg16(
                    train_rec_iFrames,
                    end_point=args.perceptual_loss_endpoint).features

        print('Global parameters:{}'.format(
            count_parameters(tf.global_variables())))
        print('Learnable model parameters:{}'.format(
            count_parameters(tf.trainable_variables())))

        # DEFINE METRICS
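        # loss_id selects the reconstruction loss: 0 = Huber, 1 = L2, 2 = L1, 3 = SSIM.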
        if args.loss_id == 0:
            train_loss = huber_loss(
                train_iFrames, train_rec_iFrames,
                delta=1.)
            val_loss = huber_loss(
                val_iFrames, val_rec_iFrames,
                delta=1.)

        elif args.loss_id == 1:
            train_loss = tf_l2_loss(
                train_iFrames, train_rec_iFrames)
            val_loss = tf_l2_loss(
                val_iFrames, val_rec_iFrames) 

        elif args.loss_id == 2:
            train_loss = l1_loss(
                train_iFrames, train_rec_iFrames)
            val_loss = l1_loss(
                val_iFrames, val_rec_iFrames)

        elif args.loss_id == 3:
            train_loss = ssim_loss(
                train_rec_iFrames, train_iFrames)
            val_loss = ssim_loss(
                val_rec_iFrames, val_iFrames)

        total_train_loss = train_loss
        tf.summary.scalar('train_main_loss', train_loss)
        tf.summary.scalar('total_val_loss', val_loss)

        if args.perceptual_loss_weight:
            train_perceptual_loss = tf_perceptual_loss(
                train_iFrames_features,
                train_rec_iFrames_features)

            tf.summary.scalar('train_perceptual_loss',\
                train_perceptual_loss)

            total_train_loss += train_perceptual_loss\
                * args.perceptual_loss_weight

        if args.weight_decay:
            decay_loss = ridge_weight_decay(
                tf.trainable_variables())

            tf.summary.scalar('ridge_l2_weight_decay',\
                decay_loss)

            total_train_loss += decay_loss\
                * args.weight_decay 

        # SUMMARIES
        tf.summary.scalar('total_train_loss',\
            total_train_loss)
        merged = tf.summary.merge_all()
        train_writer = tf.summary.FileWriter(
            CKPT_PATH + 'train',
            sess.graph)

        # DEFINE OPTIMIZER (minimize the total objective, including the
        # perceptual and weight-decay terms when enabled)
        optimizer = get_optimizer(
            total_train_loss,
            optim_id=args.optim_id,
            learning_rate=args.learning_rate,
            use_batch_norm=True)

        init_op = tf.group(
            tf.global_variables_initializer(),
            tf.local_variables_initializer())
        saver = tf.train.Saver()

        sess.run(init_op)

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(
            coord=coord)

        # START TRAINING HERE
        for iteration in range(args.train_iters + 1):
            _, t_summ, t_loss = sess.run(
                [optimizer, merged, total_train_loss])

            train_writer.add_summary(t_summ, iteration)
            print('Iter:{}/{}, Train Loss:{}'.format(
                iteration,
                args.train_iters,
                t_loss))

            if iteration % args.val_every == 0:
                v_loss = sess.run(val_loss)
                print('Iter:{}, Val Loss:{}'.format(
                    iteration,
                    v_loss))

            if iteration % args.save_every == 0:
                saver.save(
                    sess,
                    CKPT_PATH + 'iter:{}_val:{}'.format(
                        str(iteration),
                        str(round(v_loss, 3))))

            if iteration % args.plot_every == 0:
                start_frames, end_frames, mid_frames,\
                    rec_mid_frames = sess.run(
                        [train_fFrames, train_lFrames,\
                            train_iFrames,\
                            train_rec_iFrames])

                visualize_frames(
                    start_frames,
                    end_frames,
                    mid_frames,
                    rec_mid_frames,
                    training=True,
                    iteration=iteration,
                    save_path=os.path.join(
                        CKPT_PATH,
                        'train_plots/'))

                start_frames, end_frames, mid_frames,\
                    rec_mid_frames = sess.run(
                        [val_fFrames, val_lFrames,\
                            val_iFrames,
                            val_rec_iFrames])

                visualize_frames(
                    start_frames,
                    end_frames,
                    mid_frames,
                    rec_mid_frames,
                    training=False,
                    iteration=iteration,
                    save_path=os.path.join(
                        CKPT_PATH,
                        'validation_plots/'))

        print('Training complete.....')