Example 1
def main(args):
    print(args)
    check_args(args)
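    # The generator lives on the last GPU; with more than two GPUs the
    # detector's data-parallel outputs are gathered onto that same device.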
    detector_gather_device = args.num_gpus - 1
    sg2im_device = torch.device(args.num_gpus - 1)
    args.detector_gather_device = detector_gather_device
    args.sg2im_device = sg2im_device
    if not exists(args.output_dir):
        os.makedirs(args.output_dir)
    summary_writer = SummaryWriter(args.output_dir)

    # Instead of build_loaders, the vocab is derived from the detector's
    # training set: object indices map to the detector's class names.
    vocab = {
        'object_idx_to_name': load_detector.train.ind_to_classes,
    }
    model, model_kwargs = build_model(args)

    # The generator runs on a single device; DataParallel is not used here.
    model = model.to(sg2im_device)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args)  #, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if obj_discriminator is not None:
        obj_discriminator = obj_discriminator.to(sg2im_device)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator = img_discriminator.to(sg2im_device)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:
            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
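        # Fresh bookkeeping dict; model and optimizer states stay None until
        # the first checkpoint save fills them in.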
        checkpoint = {
            # 'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for step, batch in enumerate(
                tqdm(load_detector.train_loader,
                     desc='Training Epoch %d' % epoch,
                     total=len(load_detector.train_loader))):
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=args.learning_rate)
            t += 1
            # Objects and boxes are not unpacked from the batch here; they come
            # from the detector's predictions below.

            with timeit('forward', args.timing):
                # Resize images to the generator's input resolution and move
                # them onto its device.
                imgs = F.interpolate(batch.imgs,
                                     size=args.image_size).to(sg2im_device)
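                # Run the detector without tracking gradients: its predictions
                # condition the generator but receive no gradient updates.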
                with torch.no_grad():
                    if args.num_gpus > 2:
                        result = load_detector.detector.__getitem__(
                            batch, target_device=detector_gather_device)
                    else:
                        result = load_detector.detector[batch]
                objs = result.obj_preds
                boxes = result.rm_box_priors
                obj_to_img = result.im_inds
                obj_fmap = result.obj_fmap
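                # With >2 GPUs the detector gathered directly onto the sg2im
                # GPU; with exactly 2 GPUs move its outputs there explicitly.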
                if args.num_gpus == 2:
                    objs = objs.to(sg2im_device)
                    boxes = boxes.to(sg2im_device)
                    obj_to_img = obj_to_img.to(sg2im_device)
                    obj_fmap = obj_fmap.to(sg2im_device)

                # Normalize box coordinates by the detector's input scale (to roughly [0, 1]).
                boxes /= load_detector.IM_SCALE
                # Drop images for which the detector found no objects, and remap
                # obj_to_img indices onto the filtered image tensor.
                cnt = torch.bincount(obj_to_img.long(), minlength=len(imgs))
                if (cnt > 0).sum() != len(imgs):
                    print("some imgs have no detection")
                    print(cnt)
                    imgs = imgs[cnt > 0]
                    obj_to_img_new = obj_to_img.clone()
                    for i in range(len(cnt)):
                        if cnt[i] == 0:
                            # objects pointing past a dropped image shift down by one
                            obj_to_img_new -= (obj_to_img > i).long()
                    obj_to_img = obj_to_img_new

                model_out = model(obj_to_img, boxes, obj_fmap)
                imgs_pred = model_out

            with timeit('loss', args.timing):
                # Boxes come from the detector rather than ground truth, so the
                # pixel loss is always computed.
                skip_pixel_loss = False
                total_loss, losses = calculate_model_losses(
                    args, skip_pixel_loss, model, imgs, imgs_pred)

            if obj_discriminator is not None:
                scores_fake, ac_loss = obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)
                total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                      args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_obj_loss', weight)

            if img_discriminator is not None:
                scores_fake = img_discriminator(imgs_pred)
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

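            # The discriminator updates below score detached generator output,
            # so their losses do not backpropagate into the generator.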
            if obj_discriminator is not None:
                d_obj_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake, ac_loss_fake = obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                scores_real, ac_loss_real = obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

                d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
                d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

                optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                optimizer_d_obj.step()

            if img_discriminator is not None:
                d_img_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake = img_discriminator(imgs_fake)
                scores_real = img_discriminator(imgs)

                d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                G_loss_list = []
                for name, val in losses.items():
                    # print(' G [%s]: %.4f' % (name, val))
                    G_loss_list.append('[%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                    summary_writer.add_scalar("G_%s" % name, val, t)
                print("G: %s" % ", ".join(G_loss_list))
                checkpoint['losses_ts'].append(t)

                if obj_discriminator is not None:
                    D_obj_loss_list = []
                    for name, val in d_obj_losses.items():
                        # print(' D_obj [%s]: %.4f' % (name, val))
                        D_obj_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_obj_%s" % name, val, t)
                    print("D_obj: %s" % ", ".join(D_obj_loss_list))

                if img_discriminator is not None:
                    D_img_loss_list = []
                    for name, val in d_img_losses.items():
                        # print(' D_img [%s]: %.4f' % (name, val))
                        D_img_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_img_%s" % name, val, t)
                    print("D_img: %s" % ", ".join(D_img_loss_list))

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, t,
                                            load_detector.train_loader, model)
                t_losses, t_samples, t_batch_data = train_results

                checkpoint['train_batch_data'].append(t_batch_data)
                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                for name, images in t_samples.items():
                    summary_writer.add_image("train_%s" % name, images, t)

                print('checking on val')
                val_results = check_model(args, t, load_detector.val_loader,
                                          model)
                val_losses, val_samples, val_batch_data = val_results
                checkpoint['val_samples'].append(val_samples)
                checkpoint['val_batch_data'].append(val_batch_data)
                for name, images in val_samples.items():
                    summary_writer.add_image("val_%s" % name, images, t)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                    summary_writer.add_scalar("val_%s" % k, v, t)
                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
Example 2
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    print(model)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    if obj_discriminator is not None:
        obj_discriminator.type(float_dtype)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator.type(float_dtype)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:
            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for batch in train_loader:
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=args.learning_rate)
            t += 1
            batch = [tensor.cuda() for tensor in batch]
            masks = None
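            # A 6-tensor batch carries no masks; a 7-tensor batch includes
            # per-object instance masks.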
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            else:
                raise ValueError('Unexpected batch length: %d' % len(batch))
            predicates = triples[:, 1]

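            # The generator is conditioned on ground-truth boxes and masks
            # during training (boxes_gt / masks_gt below).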
            with timeit('forward', args.timing):
                model_boxes = boxes
                model_masks = masks
                model_out = model(objs,
                                  triples,
                                  obj_to_img,
                                  boxes_gt=model_boxes,
                                  masks_gt=model_masks)
                imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                skip_pixel_loss = (model_boxes is None)
                total_loss, losses = calculate_model_losses(
                    args, skip_pixel_loss, model, imgs, imgs_pred, boxes,
                    boxes_pred, masks, masks_pred, predicates,
                    predicate_scores)

            if obj_discriminator is not None:
                scores_fake, ac_loss = obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)
                total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                      args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_obj_loss', weight)

            if img_discriminator is not None:
                scores_fake = img_discriminator(imgs_pred)
                weight = args.discriminator_loss_weight * args.d_img_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            if obj_discriminator is not None:
                d_obj_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake, ac_loss_fake = obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                scores_real, ac_loss_real = obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

                d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
                d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

                optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                optimizer_d_obj.step()

            if img_discriminator is not None:
                d_img_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake = img_discriminator(imgs_fake)
                scores_real = img_discriminator(imgs)

                d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                for name, val in losses.items():
                    print(' G [%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                checkpoint['losses_ts'].append(t)

                if obj_discriminator is not None:
                    for name, val in d_obj_losses.items():
                        print(' D_obj [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

                if img_discriminator is not None:
                    for name, val in d_img_losses.items():
                        print(' D_img [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, t, train_loader, model)
                t_losses, t_samples, t_batch_data, t_avg_iou = train_results

                checkpoint['train_batch_data'].append(t_batch_data)
                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                checkpoint['train_iou'].append(t_avg_iou)

                print('checking on val')
                val_results = check_model(args, t, val_loader, model)
                val_losses, val_samples, val_batch_data, val_avg_iou = val_results
                checkpoint['val_samples'].append(val_samples)
                checkpoint['val_batch_data'].append(val_batch_data)
                checkpoint['val_iou'].append(val_avg_iou)

                print('train iou: ', t_avg_iou)
                print('val iou: ', val_avg_iou)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
Example 3
def main(args):
    print(args)
    check_args(args)
    if not exists(args.output_dir):
        os.makedirs(args.output_dir)
    summary_writer = SummaryWriter(args.output_dir)

    if args.coco:
        train, val = CocoDetection.splits()
        val.ids = val.ids[:args.val_size]
        train.ids = train.ids
        train_loader, val_loader = CocoDataLoader.splits(train, val, batch_size=args.batch_size,
                                                         num_workers=args.num_workers,
                                                         num_gpus=args.num_gpus)
    else:
        train, val, _ = VG.splits(num_val_im=args.val_size, filter_non_overlap=False,
                                  filter_empty_rels=False, use_proposals=args.use_proposals)
        train_loader, val_loader = VGDataLoader.splits(train, val, batch_size=args.batch_size,
                                                       num_workers=args.num_workers,
                                                       num_gpus=args.num_gpus)

    all_in_one_model = neural_motifs_sg2im_model(args, train.ind_to_classes)
    # Freeze the detector
    for n, param in all_in_one_model.detector.named_parameters():
        param.requires_grad = False
    all_in_one_model.cuda()
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

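    # Counters and the checkpoint dict come from the combined model, which
    # presumably restores them when an existing checkpoint is found.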
    t, epoch, checkpoint = all_in_one_model.t, all_in_one_model.epoch, all_in_one_model.checkpoint
    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for step, batch in enumerate(tqdm(train_loader, desc='Training Epoch %d' % epoch, total=len(train_loader))):
            if t == args.eval_mode_after:
                print('switching to eval mode')
                all_in_one_model.model.eval()
                all_in_one_model.optimizer = torch.optim.Adam(all_in_one_model.parameters(), lr=args.learning_rate)
            t += 1

            with timeit('forward', args.timing):
                result = all_in_one_model[batch]
                imgs, imgs_pred, objs = result.imgs, result.imgs_pred, result.objs
                g_scores_fake_crop = result.g_scores_fake_crop
                g_obj_scores_fake_crop = result.g_obj_scores_fake_crop
                g_scores_fake_img = result.g_scores_fake_img
                d_scores_fake_crop = result.d_scores_fake_crop
                d_obj_scores_fake_crop = result.d_obj_scores_fake_crop
                d_scores_real_crop = result.d_scores_real_crop
                d_obj_scores_real_crop = result.d_obj_scores_real_crop
                d_scores_fake_img = result.d_scores_fake_img
                d_scores_real_img = result.d_scores_real_img

            with timeit('loss', args.timing):
                total_loss, losses = calculate_model_losses(
                    args, imgs, imgs_pred)

                if all_in_one_model.obj_discriminator is not None:
                    total_loss = add_loss(total_loss, F.cross_entropy(g_obj_scores_fake_crop, objs), losses, 'ac_loss',
                                          args.ac_loss_weight)
                    weight = args.discriminator_loss_weight * args.d_obj_weight
                    total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_crop), losses,
                                          'g_gan_obj_loss', weight)

                if all_in_one_model.img_discriminator is not None:
                    weight = args.discriminator_loss_weight * args.d_img_weight
                    total_loss = add_loss(total_loss, gan_g_loss(g_scores_fake_img), losses,
                                          'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            with timeit('backward', args.timing):
                all_in_one_model.optimizer.zero_grad()
                total_loss.backward()
                all_in_one_model.optimizer.step()

            if all_in_one_model.obj_discriminator is not None:
                with timeit('d_obj loss', args.timing):
                    d_obj_losses = LossManager()
                    d_obj_gan_loss = gan_d_loss(d_scores_real_crop, d_scores_fake_crop)
                    d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                    d_obj_losses.add_loss(F.cross_entropy(d_obj_scores_real_crop, objs), 'd_ac_loss_real')
                    d_obj_losses.add_loss(F.cross_entropy(d_obj_scores_fake_crop, objs), 'd_ac_loss_fake')

                with timeit('d_obj backward', args.timing):
                    all_in_one_model.optimizer_d_obj.zero_grad()
                    d_obj_losses.total_loss.backward()
                    all_in_one_model.optimizer_d_obj.step()

            if all_in_one_model.img_discriminator is not None:
                with timeit('d_img loss', args.timing):
                    d_img_losses = LossManager()
                    d_img_gan_loss = gan_d_loss(d_scores_real_img, d_scores_fake_img)
                    d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                with timeit('d_img backward', args.timing):
                    all_in_one_model.optimizer_d_img.zero_grad()
                    d_img_losses.total_loss.backward()
                    all_in_one_model.optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                G_loss_list = []
                for name, val in losses.items():
                    G_loss_list.append('[%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                    summary_writer.add_scalar("G_%s" % name, val, t)
                print("G: %s" % ", ".join(G_loss_list))
                checkpoint['losses_ts'].append(t)

                if all_in_one_model.obj_discriminator is not None:
                    D_obj_loss_list = []
                    for name, val in d_obj_losses.items():
                        D_obj_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_obj_%s" % name, val, t)
                    print("D_obj: %s" % ", ".join(D_obj_loss_list))

                if all_in_one_model.img_discriminator is not None:
                    D_img_loss_list = []
                    for name, val in d_img_losses.items():
                        D_img_loss_list.append('[%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)
                        summary_writer.add_scalar("D_img_%s" % name, val, t)
                    print("D_img: %s" % ", ".join(D_img_loss_list))

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args, train_loader, all_in_one_model)
                t_losses, t_samples = train_results

                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                for name, images in t_samples.items():
                    summary_writer.add_image("train_%s" % name, images, t)

                print('checking on val')
                val_results = check_model(args, val_loader, all_in_one_model)
                val_losses, val_samples = val_results
                checkpoint['val_samples'].append(val_samples)
                for name, images in val_samples.items():
                    summary_writer.add_image("val_%s" % name, images, t)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                    summary_writer.add_scalar("val_%s" % k, v, t)
                checkpoint['model_state'] = all_in_one_model.model.state_dict()

                if all_in_one_model.obj_discriminator is not None:
                    checkpoint['d_obj_state'] = all_in_one_model.obj_discriminator.state_dict()
                    checkpoint['d_obj_optim_state'] = all_in_one_model.optimizer_d_obj.state_dict()

                if all_in_one_model.img_discriminator is not None:
                    checkpoint['d_img_state'] = all_in_one_model.img_discriminator.state_dict()
                    checkpoint['d_img_optim_state'] = all_in_one_model.optimizer_d_img.state_dict()

                checkpoint['optim_state'] = all_in_one_model.optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(args.output_dir,
                                               '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(args.output_dir,
                                               '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = ['model_state', 'optim_state', 'model_best_state',
                                 'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                                 'd_img_state', 'd_img_optim_state', 'd_img_best_state']
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
Example 4
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    print(model)

    if not os.path.isdir(args.output_dir):
        os.mkdir(args.output_dir)
        print('Created %s' % args.output_dir)

    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
    # VGG feature extractor for the perceptual loss; its weights stay frozen.
    vgg_feature_extractor = FeatureExtractor(requires_grad=False).cuda()
    # Logger for training visualization.
    logger = Logger(args.output_dir)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

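    # Matching-aware loss: mismatched (image, scene-graph context) pairs are fed
    # to the image discriminator as additional fake examples.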
    if args.matching_aware_loss and args.sg_context_dim > 0:
        gan_g_matching_aware_loss, gan_d_matching_aware_loss = get_gan_losses(
            'matching_aware_gan')

    if obj_discriminator is not None:
        obj_discriminator.type(float_dtype)
        obj_discriminator.train()
        print(obj_discriminator)
        optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                           lr=args.learning_rate)

    if img_discriminator is not None:
        img_discriminator.type(float_dtype)
        img_discriminator.train()
        print(img_discriminator)
        optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                           lr=args.learning_rate)

    ##########################################################
    if args.kubernetes:

        ###ailab experiments
        from lib.exp import Client
        import kubernetes

        namespace = os.getenv("EXPERIMENT_NAMESPACE")
        job_name = os.getenv("JOB_NAME")
        client = Client(namespace)
        experiment = client.current_experiment()
        job = client.get_job(job_name)
        experiment_result = experiment.result(job)
        try:
            result = client.create_result(experiment_result)
        except kubernetes.client.rest.ApiException as e:
            body = json.loads(e.body)
            if body['reason'] != 'AlreadyExists':
                raise e
            result = client.get_result(experiment_result.name)

    restore_path = None
    if args.restore_from_checkpoint:
        restore_path = '%s_with_model.pt' % args.checkpoint_name
        restore_path = os.path.join(args.output_dir, restore_path)
    if restore_path is not None and os.path.isfile(restore_path):
        print('Restoring from checkpoint:')
        print(restore_path)
        checkpoint = torch.load(restore_path)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

        if obj_discriminator is not None:
            obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
            optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

        if img_discriminator is not None:
            img_discriminator.load_state_dict(checkpoint['d_img_state'])
            optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

        t = checkpoint['counters']['t']
        if 0 <= args.eval_mode_after <= t:
            model.eval()
        else:
            model.train()
        epoch = checkpoint['counters']['epoch']
    else:
        t, epoch = 0, 0
        checkpoint = {
            'args': args.__dict__,
            'vocab': vocab,
            'model_kwargs': model_kwargs,
            'd_obj_kwargs': d_obj_kwargs,
            'd_img_kwargs': d_img_kwargs,
            'losses_ts': [],
            'losses': defaultdict(list),
            'd_losses': defaultdict(list),
            'checkpoint_ts': [],
            'train_batch_data': [],
            'train_samples': [],
            'train_iou': [],
            'val_batch_data': [],
            'val_samples': [],
            'val_losses': defaultdict(list),
            'val_iou': [],
            'norm_d': [],
            'norm_g': [],
            'counters': {
                't': None,
                'epoch': None,
            },
            'model_state': None,
            'model_best_state': None,
            'optim_state': None,
            'd_obj_state': None,
            'd_obj_best_state': None,
            'd_obj_optim_state': None,
            'd_img_state': None,
            'd_img_best_state': None,
            'd_img_optim_state': None,
            'best_t': [],
        }

    while True:
        if t >= args.num_iterations:
            break
        epoch += 1
        print('Starting epoch %d' % epoch)

        for batch in train_loader:
            if t == args.eval_mode_after:
                print('switching to eval mode')
                model.eval()
                optimizer = torch.optim.Adam(model.parameters(),
                                             lr=args.learning_rate)
            t += 1
            batch = [tensor.cuda() for tensor in batch]
            masks = None
            if len(batch) == 6:
                imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
            elif len(batch) == 7:
                imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
            else:
                raise ValueError('Unexpected batch length: %d' % len(batch))
            predicates = triples[:, 1]

            with timeit('forward', args.timing):
                model_boxes = boxes
                model_masks = masks
                model_out = model(objs,
                                  triples,
                                  obj_to_img,
                                  boxes_gt=model_boxes,
                                  masks_gt=model_masks)
                (imgs_pred, boxes_pred, masks_pred, objs_vec, layout,
                 layout_boxes, layout_masks, obj_to_img, sg_context_pred,
                 sg_context_pred_d, predicate_scores) = model_out

            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                skip_pixel_loss = (model_boxes is None)
                # Skip the perceptual loss if using GT boxes
                skip_perceptual_loss = (model_boxes is None)

                if args.perceptual_loss_weight:
                    total_loss, losses = calculate_model_losses(
                        args,
                        skip_pixel_loss,
                        model,
                        imgs,
                        imgs_pred,
                        boxes,
                        boxes_pred,
                        masks,
                        masks_pred,
                        predicates,
                        predicate_scores,
                        skip_perceptual_loss,
                        perceptual_extractor=vgg_feature_extractor)
                else:
                    total_loss, losses = calculate_model_losses(
                        args, skip_pixel_loss, model, imgs, imgs_pred, boxes,
                        boxes_pred, masks, masks_pred, predicates,
                        predicate_scores, skip_perceptual_loss)

            if obj_discriminator is not None:
                scores_fake, ac_loss = obj_discriminator(
                    imgs_pred, objs, boxes, obj_to_img)
                total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                      args.ac_loss_weight)
                weight = args.discriminator_loss_weight * args.d_obj_weight
                total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                      losses, 'g_gan_obj_loss', weight)

            if img_discriminator is not None:
                weight = args.discriminator_loss_weight * args.d_img_weight

                # scene_graph context by pooled GCNN features
                if args.sg_context_dim > 0:
                    ## concatenate => imgs, (layout_embedding), sg_context_pred
                    if args.layout_for_discrim == 1:
                        discrim_pred = torch.cat(
                            [imgs_pred, layout, sg_context_pred_d], dim=1)
                    else:
                        discrim_pred = torch.cat(
                            [imgs_pred, sg_context_pred_d], dim=1)

                    if args.matching_aware_loss:
                        # shuffle the context features to create additional mismatched fakes
                        matching_aware_size = sg_context_pred_d.size()[0]
                        s_sg_context_pred_d = sg_context_pred_d[torch.randperm(
                            matching_aware_size)]
                        if args.layout_for_discrim == 1:
                            match_aware_discrim_pred = torch.cat(
                                [imgs, layout, s_sg_context_pred_d], dim=1)
                        else:
                            match_aware_discrim_pred = torch.cat(
                                [imgs, s_sg_context_pred_d], dim=1)
                        discrim_pred = torch.cat(
                            [discrim_pred, match_aware_discrim_pred], dim=0)

                    scores_fake = img_discriminator(discrim_pred)
                    if args.matching_aware_loss:
                        total_loss = add_loss(
                            total_loss, gan_g_matching_aware_loss(scores_fake),
                            losses, 'g_gan_img_loss', weight)
                    else:
                        total_loss = add_loss(total_loss,
                                              gan_g_loss(scores_fake), losses,
                                              'g_gan_img_loss', weight)
                else:
                    scores_fake = img_discriminator(imgs_pred)
                    total_loss = add_loss(total_loss, gan_g_loss(scores_fake),
                                          losses, 'g_gan_img_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got loss = NaN, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                total_loss.backward()
            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            if obj_discriminator is not None:
                d_obj_losses = LossManager()
                imgs_fake = imgs_pred.detach()
                scores_fake, ac_loss_fake = obj_discriminator(
                    imgs_fake, objs, boxes, obj_to_img)
                scores_real, ac_loss_real = obj_discriminator(
                    imgs, objs, boxes, obj_to_img)

                d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
                d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
                d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

                optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                optimizer_d_obj.step()

            if img_discriminator is not None:
                d_img_losses = LossManager()

                imgs_fake = imgs_pred.detach()

                if args.sg_context_dim_d > 0:
                    sg_context_p = sg_context_pred_d.detach()

                layout_p = layout.detach()

                ## concatenate=> imgs_fake, (layout_embedding), sg_context_pred
                if args.sg_context_dim > 0:
                    if args.layout_for_discrim:
                        discrim_fake = torch.cat(
                            [imgs_fake, layout_p, sg_context_p], dim=1)
                        discrim_real = torch.cat(
                            [imgs, layout_p, sg_context_p], dim=1)
                    else:
                        discrim_fake = torch.cat([imgs_fake, sg_context_p],
                                                 dim=1)
                        discrim_real = torch.cat([imgs, sg_context_p], dim=1)

                    if args.matching_aware_loss:
                        # shuffle the context features to create additional mismatched fakes
                        matching_aware_size = sg_context_p.size()[0]
                        s_sg_context_p = sg_context_p[torch.randperm(
                            matching_aware_size)]
                        if args.layout_for_discrim:
                            match_aware_discrim_fake = torch.cat(
                                [imgs, layout_p, s_sg_context_p], dim=1)
                        else:
                            match_aware_discrim_fake = torch.cat(
                                [imgs, s_sg_context_p], dim=1)
                        discrim_fake = torch.cat(
                            [discrim_fake, match_aware_discrim_fake], dim=0)

                    scores_fake = img_discriminator(discrim_fake)
                    scores_real = img_discriminator(discrim_real)

                else:
                    scores_fake = img_discriminator(imgs_fake)
                    scores_real = img_discriminator(imgs)

                if args.matching_aware_loss:
                    d_img_gan_loss = gan_d_matching_aware_loss(
                        scores_real, scores_fake)
                else:
                    d_img_gan_loss = gan_d_loss(scores_real, scores_fake)

                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

                optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                optimizer_d_img.step()

            if t % args.print_every == 0:
                print('t = %d / %d' % (t, args.num_iterations))
                for name, val in losses.items():
                    print(' G [%s]: %.4f' % (name, val))
                    checkpoint['losses'][name].append(val)
                checkpoint['losses_ts'].append(t)

                if obj_discriminator is not None:
                    for name, val in d_obj_losses.items():
                        print(' D_obj [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

                if img_discriminator is not None:
                    for name, val in d_img_losses.items():
                        print(' D_img [%s]: %.4f' % (name, val))
                        checkpoint['d_losses'][name].append(val)

                # ================================================================== #
                #                        Tensorboard Logging                         #
                # ================================================================== #

                # 1. Log scalar values (scalar summary)
                for name, val in losses.items():
                    logger.scalar_summary(name, val, t)
                if obj_discriminator is not None:
                    for name, val in d_obj_losses.items():
                        logger.scalar_summary(name, val, t)
                if img_discriminator is not None:
                    for name, val in d_img_losses.items():
                        logger.scalar_summary(name, val, t)

            if t % args.checkpoint_every == 0:
                print('checking on train')
                train_results = check_model(args,
                                            t,
                                            train_loader,
                                            model,
                                            logger=logger,
                                            log_tag='Train',
                                            write_images=False)
                t_losses, t_samples, t_batch_data, t_avg_iou = train_results

                checkpoint['train_batch_data'].append(t_batch_data)
                checkpoint['train_samples'].append(t_samples)
                checkpoint['checkpoint_ts'].append(t)
                checkpoint['train_iou'].append(t_avg_iou)

                print('checking on val')
                val_results = check_model(args,
                                          t,
                                          val_loader,
                                          model,
                                          logger=logger,
                                          log_tag='Validation',
                                          write_images=True)

                val_losses, val_samples, val_batch_data, val_avg_iou = val_results
                checkpoint['val_samples'].append(val_samples)
                checkpoint['val_batch_data'].append(val_batch_data)
                checkpoint['val_iou'].append(val_avg_iou)

                print('train iou: ', t_avg_iou)
                print('val iou: ', val_avg_iou)

                for k, v in val_losses.items():
                    checkpoint['val_losses'][k].append(v)
                checkpoint['model_state'] = model.state_dict()

                if obj_discriminator is not None:
                    checkpoint['d_obj_state'] = obj_discriminator.state_dict()
                    checkpoint[
                        'd_obj_optim_state'] = optimizer_d_obj.state_dict()

                if img_discriminator is not None:
                    checkpoint['d_img_state'] = img_discriminator.state_dict()
                    checkpoint[
                        'd_img_optim_state'] = optimizer_d_img.state_dict()

                checkpoint['optim_state'] = optimizer.state_dict()
                checkpoint['counters']['t'] = t
                checkpoint['counters']['epoch'] = epoch
                checkpoint_path = os.path.join(
                    args.output_dir,
                    '%s_with_model.pt' % args.checkpoint_name)
                print('Saving checkpoint to ', checkpoint_path)
                torch.save(checkpoint, checkpoint_path)

                # Save another checkpoint without any model or optim state
                checkpoint_path = os.path.join(
                    args.output_dir, '%s_no_model.pt' % args.checkpoint_name)
                key_blacklist = [
                    'model_state', 'optim_state', 'model_best_state',
                    'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                    'd_img_state', 'd_img_optim_state', 'd_img_best_state'
                ]
                small_checkpoint = {}
                for k, v in checkpoint.items():
                    if k not in key_blacklist:
                        small_checkpoint[k] = v
                torch.save(small_checkpoint, checkpoint_path)
Example 5
    def D_step(result):
        # Unpack generator output and discriminator scores from the combined model.
        imgs, imgs_pred, objs = result.imgs, result.imgs_pred, result.objs
        d_scores_fake_crop = result.d_scores_fake_crop
        d_obj_scores_fake_crop = result.d_obj_scores_fake_crop
        d_scores_real_crop = result.d_scores_real_crop
        d_obj_scores_real_crop = result.d_obj_scores_real_crop
        d_scores_fake_img = result.d_scores_fake_img
        d_scores_real_img = result.d_scores_real_img
        d_obj_gp, d_img_gp = result.d_obj_gp, result.d_img_gp
        d_rec_feature_fake_crop = result.d_rec_feature_fake_crop
        d_rec_feature_real_crop = result.d_rec_feature_real_crop
        obj_fmaps = result.obj_fmaps
        d_scores_fake_bg = result.d_scores_fake_bg
        d_scores_real_bg = result.d_scores_real_bg
        d_bg_gp = result.d_bg_gp

        d_obj_losses, d_img_losses, d_bg_losses = None, None, None
        if all_in_one_model.obj_discriminator is not None:
            with timeit('d_obj loss', args.timing):
                d_obj_losses = LossManager()
                if args.d_obj_weight > 0:
                    d_obj_gan_loss = gan_d_loss(d_scores_real_crop,
                                                d_scores_fake_crop)
                    d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                    if args.gan_loss_type == 'wgan-gp':
                        d_obj_losses.add_loss(d_obj_gp.mean(), 'd_obj_gp',
                                              args.d_obj_gp_weight)
                if args.ac_loss_weight > 0:
                    d_obj_losses.add_loss(
                        F.cross_entropy(d_obj_scores_real_crop, objs),
                        'd_ac_loss_real')
                    d_obj_losses.add_loss(
                        F.cross_entropy(d_obj_scores_fake_crop, objs),
                        'd_ac_loss_fake')
                if args.d_obj_rec_feat_weight > 0:
                    d_obj_losses.add_loss(
                        F.l1_loss(d_rec_feature_fake_crop, obj_fmaps),
                        'd_obj_fea_rec_loss_fake')
                    d_obj_losses.add_loss(
                        F.l1_loss(d_rec_feature_real_crop, obj_fmaps),
                        'd_obj_fea_rec_loss_real')

            with timeit('d_obj backward', args.timing):
                all_in_one_model.optimizer_d_obj.zero_grad()
                d_obj_losses.total_loss.backward()
                all_in_one_model.optimizer_d_obj.step()

        if all_in_one_model.img_discriminator is not None:
            with timeit('d_img loss', args.timing):
                d_img_losses = LossManager()
                d_img_gan_loss = gan_d_loss(d_scores_real_img,
                                            d_scores_fake_img)
                d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')
                if args.gan_loss_type == 'wgan-gp':
                    d_img_losses.add_loss(d_img_gp.mean(), 'd_img_gp',
                                          args.d_img_gp_weight)

            with timeit('d_img backward', args.timing):
                all_in_one_model.optimizer_d_img.zero_grad()
                d_img_losses.total_loss.backward()
                all_in_one_model.optimizer_d_img.step()

        if all_in_one_model.bg_discriminator is not None:
            with timeit('d_bg loss', args.timing):
                d_bg_losses = LossManager()
                d_bg_gan_loss = gan_d_loss(d_scores_real_bg, d_scores_fake_bg)
                d_bg_losses.add_loss(d_bg_gan_loss, 'd_bg_gan_loss')
                if args.gan_loss_type == 'wgan-gp':
                    d_bg_losses.add_loss(d_bg_gp.mean(), 'd_bg_gp',
                                         args.d_bg_gp_weight)

            with timeit('d_bg backward', args.timing):
                all_in_one_model.optimizer_d_bg.zero_grad()
                d_bg_losses.total_loss.backward()
                all_in_one_model.optimizer_d_bg.step()

        return d_obj_losses, d_img_losses, d_bg_losses
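
# A minimal sketch of the loss pair that get_gan_losses('wgan-gp') is assumed
# to return in the snippets above and below (illustrative only; the project's
# own losses module is authoritative):
def wgan_d_loss(scores_real, scores_fake):
    # WGAN critic objective: drive fake scores below real scores.
    return scores_fake.mean() - scores_real.mean()


def wgan_g_loss(scores_fake):
    # The generator maximizes the critic's score on generated samples.
    return -scores_fake.mean()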
Example no. 6
def main(args):
    #print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    model = model.cuda()

    layoutgen = LayoutGenerator(args.batch_size,
                                args.max_objects_per_image + 1, 184).cuda()
    optimizer_params = list(model.parameters()) + list(layoutgen.parameters())
    optimizer = torch.optim.Adam(params=optimizer_params,
                                 lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    obj_discriminator = obj_discriminator.cuda()
    img_discriminator = img_discriminator.cuda()

    H, W = args.image_size
    layout_discriminator = LayoutDiscriminator(args.batch_size,
                                               args.max_objects_per_image + 1,
                                               184, H, W).cuda()

    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    obj_discriminator.type(float_dtype)
    obj_discriminator.train()
    optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                       lr=args.learning_rate)

    img_discriminator.type(float_dtype)
    img_discriminator.train()
    optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                       lr=args.learning_rate)

    optimizer_d_layout = torch.optim.Adam(
        params=layout_discriminator.parameters(), lr=args.learning_rate)

    epoch = 3

    if (args.checkpoint_start_from is not None):
        model_path = args.checkpoint_start_from

        checkpoint = torch.load(model_path)
        #epoch = checkpoint['args']['epoch']
        model.load_state_dict(checkpoint['model_state'])
        layoutgen.load_state_dict(checkpoint['layout_gen'])

        obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
        img_discriminator.load_state_dict(checkpoint['d_img_state'])
        layout_discriminator.load_state_dict(checkpoint['d_layout_state'])

        optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])
        optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])
        optimizer_d_layout.load_state_dict(checkpoint['d_layout_optim_state'])
        optimizer.load_state_dict(checkpoint['optim_state'])

    while True:
        if (epoch >= args.num_epochs):
            break

        epoch += 1
        print('Starting epoch %d' % epoch)

        for batchnum, batch in enumerate(tqdm(train_loader)):
            (imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img,
             combined, all_num_objs) = [t.cuda() for t in batch]

            if (imgs.shape[0] < args.batch_size):
                print('current size was', imgs.shape[0])
                continue
            #print('\n\nimages',imgs.shape,objs.shape,boxes.shape,masks.shape,combined.shape)
            #print(all_num_objs)
            zlist = []
            for i in range(args.batch_size):
                # four geometric noise dims per object slot, drawn from N(0, 1)
                geo_z = torch.normal(0,
                                     1,
                                     size=(args.max_objects_per_image + 1, 4))
                zlist.append(geo_z)

            zlist = torch.stack(zlist).cuda()
            zlist = torch.cat((zlist, combined[:, :, 4:]), dim=2)

            feature_vectors, logit_boxes = layoutgen(zlist.cuda())

            # logistic sigmoid squashes the box logits into [0, 1] coordinates
            generated_boxes = torch.sigmoid(logit_boxes)

            new_gen_boxes = torch.empty((0, 4)).cuda()
            new_feature_vecs = torch.empty((0, 128)).cuda()

            for kb in range(args.batch_size):
                new_gen_boxes = torch.cat([
                    new_gen_boxes,
                    torch.squeeze(generated_boxes[kb, :all_num_objs[kb], :4])
                ],
                                          dim=0)
                new_feature_vecs = torch.cat([
                    new_feature_vecs,
                    torch.squeeze(feature_vectors[kb, :all_num_objs[kb], :])
                ],
                                             dim=0)
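
            # (The loop above flattens the padded (batch, max_objs, ...) outputs
            # into (sum(all_num_objs), ...) tensors, dropping padded slots; a
            # vectorized sketch of the same gather would be:
            #   keep = (torch.arange(args.max_objects_per_image + 1,
            #                        device=all_num_objs.device)[None]
            #           < all_num_objs[:, None])
            #   new_gen_boxes = generated_boxes[..., :4][keep]
            #   new_feature_vecs = feature_vectors[keep])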

            boxes_pred = new_gen_boxes
            model_boxes = generated_boxes
            model_masks = None
            triples = None

            imgs_pred = model(new_feature_vecs, new_gen_boxes, triples,
                              obj_to_img)
            #boxes_gt=model_boxes, masks_gt=model_masks)
            #imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out

            # Skip the pixel loss if using GT boxes

            total_loss, losses = calculate_model_losses(
                args, model, imgs, imgs_pred, boxes, boxes_pred, logit_boxes,
                generated_boxes, combined)

            scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes,
                                                     obj_to_img)
            total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                  args.ac_loss_weight)
            weight = args.discriminator_loss_weight * args.d_obj_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_obj_loss', weight)

            scores_fake = img_discriminator(imgs_pred)
            weight = args.discriminator_loss_weight * args.d_img_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_img_loss', weight)

            scores_fake = layout_discriminator(logit_boxes)
            weight = args.discriminator_loss_weight * args.d_img_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_layout_loss', weight)
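            # (Note: the layout term above reuses args.d_img_weight; these
            # scripts define no separate d_layout_weight.)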

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got a non-finite loss, not backpropping')
                continue

            optimizer.zero_grad()

            #print('Total loss:',total_loss)
            total_loss.backward()

            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            d_obj_losses = LossManager()
            imgs_fake = imgs_pred.detach().cuda()
            scores_fake, ac_loss_fake = obj_discriminator(
                imgs_fake, objs, boxes, obj_to_img)
            scores_real, ac_loss_real = obj_discriminator(
                imgs, objs, boxes, obj_to_img)

            d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
            d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
            d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

            optimizer_d_obj.zero_grad()
            d_obj_losses.total_loss.backward()
            optimizer_d_obj.step()

            d_img_losses = LossManager()
            imgs_fake = imgs_pred.detach().cuda()
            scores_fake = img_discriminator(imgs_fake)
            scores_real = img_discriminator(imgs)

            d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

            optimizer_d_img.zero_grad()
            d_img_losses.total_loss.backward()
            optimizer_d_img.step()

            d_layout_losses = LossManager()
            layout_fake = logit_boxes.detach()
            scores_fake = layout_discriminator(layout_fake)
            scores_real = layout_discriminator(combined)

            d_layout_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_layout_losses.add_loss(d_layout_gan_loss, 'd_layout_gan_loss')

            optimizer_d_layout.zero_grad()
            d_layout_losses.total_loss.backward()
            optimizer_d_layout.step()

            if ((batchnum + 1) % 10 == 0):
                towrite = ('\n|Epoch:' + str(epoch) + ' | Batch:' +
                           str(batchnum) + ' | layout Loss:' +
                           str(d_layout_losses.total_loss.item()) +
                           ' | img disc loss:' +
                           str(d_img_losses.total_loss.item()) +
                           ' | obj disc loss:' +
                           str(d_obj_losses.total_loss.item()) +
                           ' | total gen loss:' + str(total_loss.item()))
                with open('stats/training_stats.txt', 'a+') as f:
                    f.write(towrite)

            if ((batchnum + 1) % 100 == 0):
                checkpoint = {
                    'args': {
                        'epoch': epoch
                    },
                    'model_state': model.state_dict(),
                    'layout_gen': layoutgen.state_dict(),
                    'd_obj_state': obj_discriminator.state_dict(),
                    'd_img_state': img_discriminator.state_dict(),
                    'd_layout_state': layout_discriminator.state_dict(),
                    'd_obj_optim_state': optimizer_d_obj.state_dict(),
                    'd_img_optim_state': optimizer_d_img.state_dict(),
                    'd_layout_optim_state': optimizer_d_layout.state_dict(),
                    'optim_state': optimizer.state_dict(),
                }
                print('Saving checkpoint to ', args.checkpoint_folder)
                checkpoint_path = os.path.join(
                    args.checkpoint_folder, 'epoch_' + str(epoch) + '_batch_' +
                    str(batchnum) + '_with_model.pt')
                torch.save(checkpoint, checkpoint_path)

        checkpoint = {
            'args': {
                'epoch': epoch
            },
            'model_state': model.state_dict(),
            'layout_gen': layoutgen.state_dict(),
            'd_obj_state': obj_discriminator.state_dict(),
            'd_img_state': img_discriminator.state_dict(),
            'd_layout_state': layout_discriminator.state_dict(),
            'd_obj_optim_state': optimizer_d_obj.state_dict(),
            'd_img_optim_state': optimizer_d_img.state_dict(),
            'd_layout_optim_state': optimizer_d_layout.state_dict(),
            'optim_state': optimizer.state_dict(),
        }
        print('Saving checkpoint to ', args.checkpoint_folder)
        checkpoint_path = os.path.join(
            args.checkpoint_folder, 'epoch_' + str(epoch) + '_with_model.pt')
        torch.save(checkpoint, checkpoint_path)
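
# Both the script above and the one below assume roughly this LayoutGenerator
# interface (a sketch: the real class is defined elsewhere in the project, and
# 184 presumably matches the last dimension of the zlist input):
#
#   layoutgen = LayoutGenerator(batch_size, max_objs, 184)
#   feature_vectors, logit_boxes = layoutgen(z)
#   # feature_vectors: (batch, max_objs, 128); logit_boxes: (batch, max_objs, 4)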
Example no. 7
def main(args):
    print(args)
    check_args(args)
    float_dtype = torch.cuda.FloatTensor
    long_dtype = torch.cuda.LongTensor

    vocab, train_loader, val_loader = build_loaders(args)
    model, model_kwargs = build_model(args, vocab)
    model.type(float_dtype)
    model = model.cuda()
    layoutgen = LayoutGenerator(args.batch_size,
                                args.max_objects_per_image + 1, 184).cuda()
    optimizer_params = list(model.parameters()) + list(layoutgen.parameters())
    optimizer = torch.optim.Adam(params=optimizer_params,
                                 lr=args.learning_rate)

    obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
    img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
    obj_discriminator = obj_discriminator.cuda()
    img_discriminator = img_discriminator.cuda()
    layout_discriminator = LayoutDiscriminator(args.batch_size,
                                               args.max_objects_per_image + 1,
                                               184, 64, 64).cuda()

    gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)

    obj_discriminator.type(float_dtype)
    obj_discriminator.train()
    optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                       lr=args.learning_rate)

    img_discriminator.type(float_dtype)
    img_discriminator.train()
    optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                       lr=args.learning_rate)

    optimizer_d_layout = torch.optim.Adam(
        params=layout_discriminator.parameters(), lr=args.learning_rate)

    model_path = 'stats/epoch_2_batch_399_with_model.pt'
    checkpoint = torch.load(model_path)
    model.load_state_dict(checkpoint['model_state'])
    layoutgen.load_state_dict(checkpoint['layout_gen'])

    obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
    img_discriminator.load_state_dict(checkpoint['d_img_state'])
    layout_discriminator.load_state_dict(checkpoint['d_layout_state'])

    optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])
    optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])
    optimizer_d_layout.load_state_dict(checkpoint['d_layout_optim_state'])
    optimizer.load_state_dict(checkpoint['optim_state'])


    epoch = 2
    while True:
        if (epoch >= 20):
            break

        epoch += 1
        print('Starting epoch %d' % epoch)

        for batchnum, batch in enumerate(tqdm(train_loader)):
            (imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img,
             combined, all_num_objs) = [t.cuda() for t in batch]

            if (imgs.shape[0] < args.batch_size):
                print('current size was', imgs.shape[0])
                continue
            #print('\n\nimages',imgs.shape,objs.shape,boxes.shape,masks.shape,combined.shape)
            #print(all_num_objs)
            zlist = []
            for i in range(args.batch_size):
                geo_z = torch.normal(0,
                                     1,
                                     size=(args.max_objects_per_image + 1, 4))
                zlist.append(geo_z)

            zlist = torch.stack(zlist).cuda()
            zlist = torch.cat((zlist, combined[:, :, 4:]), dim=2)

            feature_vectors, logit_boxes = layoutgen(zlist.cuda())
            generated_boxes = torch.sigmoid(logit_boxes)

            new_gen_boxes = torch.empty((0, 4)).cuda()
            new_feature_vecs = torch.empty((0, 128)).cuda()

            for kb in range(args.batch_size):
                new_gen_boxes = torch.cat([
                    new_gen_boxes,
                    torch.squeeze(generated_boxes[kb, :all_num_objs[kb], :4])
                ],
                                          dim=0)
                new_feature_vecs = torch.cat([
                    new_feature_vecs,
                    torch.squeeze(feature_vectors[kb, :all_num_objs[kb], :])
                ],
                                             dim=0)

            #print('Shape of new gen boxes:',new_gen_boxes.shape)
            #print('Shape of new feature vec:',new_feature_vecs.shape)
            boxes_pred = new_gen_boxes

            with timeit('forward', args.timing):
                #model_boxes = boxes
                model_boxes = generated_boxes
                #model_masks = masks
                model_masks = None
                triples = None

                imgs_pred = model(new_feature_vecs, new_gen_boxes, triples,
                                  obj_to_img)
                #boxes_gt=model_boxes, masks_gt=model_masks)
                #imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
            with timeit('loss', args.timing):
                # Skip the pixel loss if using GT boxes
                skip_pixel_loss = (model_boxes is None)
                total_loss, losses = calculate_model_losses(
                    args, model, imgs, imgs_pred, boxes, boxes_pred,
                    logit_boxes, generated_boxes, combined)

            scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes,
                                                     obj_to_img)
            total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                                  args.ac_loss_weight)
            weight = args.discriminator_loss_weight * args.d_obj_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_obj_loss', weight)

            scores_fake = img_discriminator(imgs_pred)
            weight = args.discriminator_loss_weight * args.d_img_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_img_loss', weight)

            scores_fake = layout_discriminator(logit_boxes)
            weight = args.discriminator_loss_weight * args.d_img_weight
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                  'g_gan_layout_loss', weight)

            losses['total_loss'] = total_loss.item()
            if not math.isfinite(losses['total_loss']):
                print('WARNING: Got a non-finite loss, not backpropping')
                continue

            optimizer.zero_grad()
            with timeit('backward', args.timing):
                #print('Total loss:',total_loss)
                total_loss.backward()

            optimizer.step()
            total_loss_d = None
            ac_loss_real = None
            ac_loss_fake = None
            d_losses = {}

            d_obj_losses = LossManager()
            imgs_fake = imgs_pred.detach().cuda()
            scores_fake, ac_loss_fake = obj_discriminator(
                imgs_fake, objs, boxes, obj_to_img)
            scores_real, ac_loss_real = obj_discriminator(
                imgs, objs, boxes, obj_to_img)

            d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
            d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
            d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

            optimizer_d_obj.zero_grad()
            d_obj_losses.total_loss.backward()
            optimizer_d_obj.step()

            d_img_losses = LossManager()
            imgs_fake = imgs_pred.detach().cuda()
            scores_fake = img_discriminator(imgs_fake)
            scores_real = img_discriminator(imgs)

            d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')

            optimizer_d_img.zero_grad()
            d_img_losses.total_loss.backward()
            optimizer_d_img.step()

            d_layout_losses = LossManager()
            layout_fake = logit_boxes.detach()
            scores_fake = layout_discriminator(layout_fake)
            scores_real = layout_discriminator(combined)

            d_layout_gan_loss = gan_d_loss(scores_real, scores_fake)
            d_layout_losses.add_loss(d_layout_gan_loss, 'd_layout_gan_loss')

            optimizer_d_layout.zero_grad()
            d_layout_losses.total_loss.backward()
            optimizer_d_layout.step()

            if ((batchnum + 1) % 10 == 0):
                towrite = ('\n|Epoch:' + str(epoch) + ' | layout Loss:' +
                           str(d_layout_losses.total_loss.item()) +
                           ' | img disc loss:' +
                           str(d_img_losses.total_loss.item()) +
                           ' | obj disc loss:' +
                           str(d_obj_losses.total_loss.item()) +
                           ' | total gen loss:' + str(total_loss.item()))
                with open('stats/training_stats.txt', 'a+') as f:
                    f.write(towrite)

            if ((batchnum + 1) % 100 == 0):
                checkpoint = {
                    'model_state': model.state_dict(),
                    'layout_gen': layoutgen.state_dict(),
                    'd_obj_state': obj_discriminator.state_dict(),
                    'd_img_state': img_discriminator.state_dict(),
                    'd_layout_state': layout_discriminator.state_dict(),
                    'd_obj_optim_state': optimizer_d_obj.state_dict(),
                    'd_img_optim_state': optimizer_d_img.state_dict(),
                    'd_layout_optim_state': optimizer_d_layout.state_dict(),
                    'optim_state': optimizer.state_dict(),
                }
                print('Saving checkpoint to ', 'stats/')
                checkpoint_path = os.path.join(
                    'stats/', 'epoch_' + str(epoch) + '_batch_' +
                    str(batchnum) + '_with_model.pt')
                torch.save(checkpoint, checkpoint_path)

        checkpoint = {
            'model_state': model.state_dict(),
            'layout_gen': layoutgen.state_dict(),
            'd_obj_state': obj_discriminator.state_dict(),
            'd_img_state': img_discriminator.state_dict(),
            'd_layout_state': layout_discriminator.state_dict(),
            'd_obj_optim_state': optimizer_d_obj.state_dict(),
            'd_img_optim_state': optimizer_d_img.state_dict(),
            'd_layout_optim_state': optimizer_d_layout.state_dict(),
            'optim_state': optimizer.state_dict(),
        }
        print('Saving checkpoint to ', 'stats/')
        checkpoint_path = os.path.join(
            'stats/', 'epoch_' + str(epoch) + '_with_model.pt')
        torch.save(checkpoint, checkpoint_path)
Example no. 8
def main(args):

  device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  if not torch.cuda.is_available():
    print('WARNING: CUDA not available; falling back to CPU')

  print(args)
  check_args(args)
  float_dtype = torch.cuda.FloatTensor
  long_dtype = torch.cuda.LongTensor

  vocab, train_loader, val_loader = build_loaders(args)
  model, model_kwargs = build_model(args, vocab)
  model.type(float_dtype)
  print(model)

  if not os.path.isdir(args.output_dir):
    os.mkdir(args.output_dir)
    print('Created %s' % args.output_dir)

  optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
  vgg_feature_extractor = FeatureExtractor(requires_grad=False).cuda()
  ## add code for training visualization
  logger = Logger(args.output_dir)

  obj_discriminator, d_obj_kwargs = build_obj_discriminator(args, vocab)
  img_discriminator, d_img_kwargs = build_img_discriminator(args, vocab)
  gan_g_loss, gan_d_loss = get_gan_losses(args.gan_loss_type)
  
  if args.matching_aware_loss and args.sg_context_dim > 0:
    gan_g_matching_aware_loss, gan_d_matching_aware_loss = get_gan_losses('matching_aware_gan')

  ### quick hack: disable both discriminators so only the generator trains
  obj_discriminator = None
  img_discriminator = None


  ############
  if obj_discriminator is not None:
    obj_discriminator.type(float_dtype)
    obj_discriminator.train()
    print(obj_discriminator)
    optimizer_d_obj = torch.optim.Adam(obj_discriminator.parameters(),
                                       lr=args.learning_rate)

  if img_discriminator is not None:
    img_discriminator.type(float_dtype)
    img_discriminator.train()
    print(img_discriminator)
    optimizer_d_img = torch.optim.Adam(img_discriminator.parameters(),
                                       lr=args.learning_rate)

  restore_path = None
  if args.checkpoint_start_from is not None:
    restore_path = args.checkpoint_start_from
  elif args.restore_from_checkpoint:
    restore_path = '%s_with_model.pt' % args.checkpoint_name
    restore_path = os.path.join(args.output_dir, restore_path)
  if restore_path is not None and os.path.isfile(restore_path):
    print('Restoring from checkpoint:')
    print(restore_path)
    checkpoint = torch.load(restore_path)
    model.load_state_dict(checkpoint['model_state'])
    optimizer.load_state_dict(checkpoint['optim_state'])

    if obj_discriminator is not None:
      obj_discriminator.load_state_dict(checkpoint['d_obj_state'])
      optimizer_d_obj.load_state_dict(checkpoint['d_obj_optim_state'])

    if img_discriminator is not None:

      img_discriminator.load_state_dict(checkpoint['d_img_state'])
      optimizer_d_img.load_state_dict(checkpoint['d_img_optim_state'])

    t = checkpoint['counters']['t']
    if 0 <= args.eval_mode_after <= t:
      model.eval()
    else:
      model.train()
    epoch = checkpoint['counters']['epoch']
  else:
    t, epoch = 0, 0
    checkpoint = {
      'args': args.__dict__,
      'vocab': vocab,
      'model_kwargs': model_kwargs,
      'd_obj_kwargs': d_obj_kwargs,
      'd_img_kwargs': d_img_kwargs,
      'losses_ts': [],
      'losses': defaultdict(list),
      'd_losses': defaultdict(list),
      'checkpoint_ts': [],
      'train_batch_data': [], 
      'train_samples': [],
      'train_iou': [],
      'val_batch_data': [], 
      'val_samples': [],
      'val_losses': defaultdict(list),
      'val_iou': [], 
      'norm_d': [], 
      'norm_g': [],
      'counters': {
        't': None,
        'epoch': None,
      },
      'model_state': None, 'model_best_state': None, 'optim_state': None,
      'd_obj_state': None, 'd_obj_best_state': None, 'd_obj_optim_state': None,
      'd_img_state': None, 'd_img_best_state': None, 'd_img_optim_state': None,
      'best_t': [],
    }


  while True:
    if t >= args.num_iterations:
      break
    epoch += 1
    print('Starting epoch %d' % epoch)
   
    for batch in train_loader:
      if t == args.eval_mode_after:
        print('switching to eval mode')
        model.eval()
        optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
      t += 1
      batch = [tensor.cuda() for tensor in batch]
      masks = None
      if len(batch) == 6:
        imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch
        triplet_masks = None 
      elif len(batch) == 8:
        imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img, triplet_masks = batch
      #elif len(batch) == 7:
      #  imgs, objs, boxes, masks, triples, obj_to_img, triple_to_img = batch
      else:
        assert False
      predicates = triples[:, 1]

      with timeit('forward', args.timing):
        model_boxes = boxes
        model_masks = masks
        model_out = model(objs, triples, obj_to_img,
                          boxes_gt=model_boxes, masks_gt=model_masks
                          )
        # imgs_pred, boxes_pred, masks_pred, predicate_scores = model_out
        (imgs_pred, boxes_pred, masks_pred, objs_vec, layout, layout_boxes,
         layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d,
         predicate_scores, obj_embeddings, pred_embeddings, triplet_boxes_pred,
         triplet_boxes, triplet_masks_pred, boxes_pred_info,
         triplet_superboxes_pred) = model_out
        # boxes_pred, masks_pred, objs_vec, layout, layout_boxes, layout_masks, obj_to_img, sg_context_pred, sg_context_pred_d, predicate_scores = model_out
     
      # add additional information for GT boxes (hack to not change coco.py)
      boxes_info = None
      if args.use_bbox_info and boxes_pred_info is not None:
        boxes_info = add_bbox_info(boxes) 
      # GT for triplet superbox
      triplet_superboxes = None
      if args.triplet_superbox_net and triplet_superboxes_pred is not None:
        # triplet_boxes = [ x1_0 y1_0 x1_1 y1_1 x2_0 y2_0 x2_1 y2_1]
        min_pts = triplet_boxes[:,:2]
        max_pts = triplet_boxes[:,6:8]
        triplet_superboxes = torch.cat([min_pts, max_pts], dim=1)
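        # (As written, this takes the first box's top-left and the second box's
        # bottom-right verbatim; a coordinate-wise union box, if that was the
        # intent, would instead be:
        #   min_pts = torch.min(triplet_boxes[:, 0:2], triplet_boxes[:, 4:6])
        #   max_pts = torch.max(triplet_boxes[:, 2:4], triplet_boxes[:, 6:8]))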

      with timeit('loss', args.timing):
        # Skip the pixel loss if using GT boxes
        #skip_pixel_loss = (model_boxes is None)
        skip_pixel_loss = True 
        # Skip the perceptual loss if using GT boxes
        #skip_perceptual_loss = (model_boxes is None)
        skip_perceptual_loss = True 

        if args.perceptual_loss_weight:
          total_loss, losses = calculate_model_losses(
                                  args, skip_pixel_loss, model, imgs, imgs_pred,
                                  boxes, boxes_pred, masks, masks_pred,
                                  boxes_info, boxes_pred_info,
                                  predicates, predicate_scores,
                                  triplet_boxes, triplet_boxes_pred,
                                  triplet_masks, triplet_masks_pred,
                                  triplet_superboxes, triplet_superboxes_pred,
                                  skip_perceptual_loss,
                                  perceptual_extractor=vgg_feature_extractor)
        else:
          total_loss, losses =  calculate_model_losses(
                                    args, skip_pixel_loss, model, imgs, imgs_pred,
                                    boxes, boxes_pred, masks, masks_pred,
                                    boxes_info, boxes_pred_info,
                                    predicates, predicate_scores, 
                                    triplet_boxes, triplet_boxes_pred, 
                                    triplet_masks, triplet_masks_pred,
                                    triplet_superboxes, triplet_superboxes_pred,
                                    skip_perceptual_loss)  

      if obj_discriminator is not None:
        scores_fake, ac_loss = obj_discriminator(imgs_pred, objs, boxes, obj_to_img)
        total_loss = add_loss(total_loss, ac_loss, losses, 'ac_loss',
                              args.ac_loss_weight)
        weight = args.discriminator_loss_weight * args.d_obj_weight
        total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                              'g_gan_obj_loss', weight)

      if img_discriminator is not None:
        weight = args.discriminator_loss_weight * args.d_img_weight
      
        # scene_graph context by pooled GCNN features
        if args.sg_context_dim > 0:
          ## concatenate => imgs, (layout_embedding), sg_context_pred
          if args.layout_for_discrim == 1:
            discrim_pred = torch.cat([imgs_pred, layout, sg_context_pred_d], dim=1) 
          else:
            discrim_pred = torch.cat([imgs_pred, sg_context_pred_d], dim=1)  
          
          if args.matching_aware_loss:
            # shuffle sg_context_pred_d to use additional fake examples with real images
            matching_aware_size = sg_context_pred_d.size()[0]
            s_sg_context_pred_d = sg_context_pred_d[torch.randperm(matching_aware_size)]  
            if args.layout_for_discrim == 1:
              match_aware_discrim_pred = torch.cat([imgs, layout, s_sg_context_pred_d], dim=1) 
            else: 
              match_aware_discrim_pred = torch.cat([imgs, s_sg_context_pred_d], dim=1 )           
            discrim_pred = torch.cat([discrim_pred, match_aware_discrim_pred], dim=0)         
          
          scores_fake = img_discriminator(discrim_pred)         
          if args.matching_aware_loss:
            total_loss = add_loss(total_loss, gan_g_matching_aware_loss(scores_fake), losses,
                            'g_gan_img_loss', weight)
          else:
            total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                              'g_gan_img_loss', weight)
        else:
          scores_fake = img_discriminator(imgs_pred)
          #weight = args.discriminator_loss_weight * args.d_img_weight
          total_loss = add_loss(total_loss, gan_g_loss(scores_fake), losses,
                                'g_gan_img_loss', weight)

      losses['total_loss'] = total_loss.item()
      if not math.isfinite(losses['total_loss']):
        print('WARNING: Got a non-finite loss, not backpropping')
        continue

      optimizer.zero_grad()
      with timeit('backward', args.timing):
        total_loss.backward()
      optimizer.step()
      total_loss_d = None
      ac_loss_real = None
      ac_loss_fake = None
      d_losses = {}
      
      if obj_discriminator is not None:
        d_obj_losses = LossManager()
        imgs_fake = imgs_pred.detach()
        scores_fake, ac_loss_fake = obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
        scores_real, ac_loss_real = obj_discriminator(imgs, objs, boxes, obj_to_img)

        d_obj_gan_loss = gan_d_loss(scores_real, scores_fake)
        d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
        d_obj_losses.add_loss(ac_loss_real, 'd_ac_loss_real')
        d_obj_losses.add_loss(ac_loss_fake, 'd_ac_loss_fake')

        optimizer_d_obj.zero_grad()
        d_obj_losses.total_loss.backward()
        optimizer_d_obj.step()


      if img_discriminator is not None:
        d_img_losses = LossManager()
               
        imgs_fake = imgs_pred.detach()

        if args.sg_context_dim_d > 0:
          sg_context_p = sg_context_pred_d.detach()  
        
        layout_p = layout.detach()
        # layout_gt_p = layout_gt.detach()

        ## concatenate=> imgs_fake, (layout_embedding), sg_context_pred
        if args.sg_context_dim > 0:
          if args.layout_for_discrim:
            discrim_fake = torch.cat([imgs_fake, layout_p, sg_context_p], dim=1 )  
            discrim_real = torch.cat([imgs, layout_p, sg_context_p], dim=1 ) 
            # discrim_real = torch.cat([imgs, layout_gt_p, sg_context_p], dim=1 ) 
          else:   
            discrim_fake = torch.cat([imgs_fake, sg_context_p], dim=1 ) 
            discrim_real = torch.cat([imgs, sg_context_p], dim=1 )   
              
          if args.matching_aware_loss:
            # shuffle sg_context_p to use additional fake examples with real images
            matching_aware_size = sg_context_p.size()[0]
            s_sg_context_p = sg_context_p[torch.randperm(matching_aware_size)]
            # s_sg_context_p = sg_context_p[torch.randperm(args.batch_size)]
            if args.layout_for_discrim:
              match_aware_discrim_fake = torch.cat([imgs, layout_p, s_sg_context_p], dim=1 ) 
            else:
              match_aware_discrim_fake = torch.cat([imgs, s_sg_context_p], dim=1 ) 
            discrim_fake = torch.cat([discrim_fake, match_aware_discrim_fake], dim=0)   

          scores_fake = img_discriminator(discrim_fake)
          scores_real = img_discriminator(discrim_real)
        else:
          scores_fake = img_discriminator(imgs_fake)
          scores_real = img_discriminator(imgs)

        # one matching-aware switch covers both branches above
        if args.matching_aware_loss:
          d_img_gan_loss = gan_d_matching_aware_loss(scores_real, scores_fake)
        else:
          d_img_gan_loss = gan_d_loss(scores_real, scores_fake)
          
        d_img_losses.add_loss(d_img_gan_loss, 'd_img_gan_loss')
        
        optimizer_d_img.zero_grad()
        d_img_losses.total_loss.backward()
        optimizer_d_img.step()

      # report intermediary values to stdout
      if t % args.print_every == 0:
        print('t = %d / %d' % (t, args.num_iterations))
        for name, val in losses.items():
          print(' G [%s]: %.4f' % (name, val))
          checkpoint['losses'][name].append(val)
        checkpoint['losses_ts'].append(t)

        if obj_discriminator is not None:
          for name, val in d_obj_losses.items():
            print(' D_obj [%s]: %.4f' % (name, val))
            checkpoint['d_losses'][name].append(val)

        if img_discriminator is not None:
          for name, val in d_img_losses.items():
            print(' D_img [%s]: %.4f' % (name, val))
            checkpoint['d_losses'][name].append(val)

        # ================================================================== #
        #                        Tensorboard Logging                         #
        # ================================================================== #

        # 1. Log scalar values (scalar summary)
        for name, val in losses.items():
            logger.scalar_summary(name, val, t)
        if obj_discriminator is not None:    
          for name, val in d_obj_losses.items():
              logger.scalar_summary(name, val, t)           
        if img_discriminator is not None:
          for name, val in d_img_losses.items():
              logger.scalar_summary(name, val, t)   
        logger.scalar_summary('score', t, t)

        if t % args.check_val_metrics_every == 0 and t > args.print_every:
          print('Checking val metrics')
          rel_score, avg_iou = get_rel_score(args, t, val_loader, model)
          logger.scalar_summary('relation_score', rel_score, t) 
          logger.scalar_summary('avg_IoU', avg_iou, t) 
          print(t, ': val relation score = ', rel_score)
          print(t, ': IoU = ', avg_iou)
          # checks entire val dataset
          val_results = check_model(args, t, val_loader, model, logger=logger, log_tag='Val', write_images=False)
          v_losses, v_samples, v_batch_data, v_avg_iou = val_results
          logger.scalar_summary('val_total_loss', v_losses['total_loss'], t)
      
      if t % args.checkpoint_every == 0:
        print('checking on train')
        train_results = check_model(args, t, train_loader, model, logger=logger, log_tag='Train', write_images=False)
        t_losses, t_samples, t_batch_data, t_avg_iou = train_results

        checkpoint['train_batch_data'].append(t_batch_data)
        checkpoint['train_samples'].append(t_samples)
        checkpoint['checkpoint_ts'].append(t)
        checkpoint['train_iou'].append(t_avg_iou)

        print('checking on val')
        val_results = check_model(args, t, val_loader, model, logger=logger, log_tag='Validation', write_images=True)

        val_losses, val_samples, val_batch_data, val_avg_iou = val_results
        checkpoint['val_samples'].append(val_samples)
        checkpoint['val_batch_data'].append(val_batch_data)
        checkpoint['val_iou'].append(val_avg_iou)
        
        print('train iou: ', t_avg_iou)
        print('val iou: ', val_avg_iou)

        for k, v in val_losses.items():
          checkpoint['val_losses'][k].append(v)
        checkpoint['model_state'] = model.state_dict()

        if obj_discriminator is not None:
          checkpoint['d_obj_state'] = obj_discriminator.state_dict()
          checkpoint['d_obj_optim_state'] = optimizer_d_obj.state_dict()

        if img_discriminator is not None:
          checkpoint['d_img_state'] = img_discriminator.state_dict()
          checkpoint['d_img_optim_state'] = optimizer_d_img.state_dict()

        checkpoint['optim_state'] = optimizer.state_dict()
        checkpoint['counters']['t'] = t
        checkpoint['counters']['epoch'] = epoch
        checkpoint_path = os.path.join(args.output_dir,
                              '%s_with_model_%d.pt' %(args.checkpoint_name, t) 
                              #'%s_with_model.pt' %args.checkpoint_name
                              )
        print('Saving checkpoint to ', checkpoint_path)
        torch.save(checkpoint, checkpoint_path)

        # Save another checkpoint without any model or optim state
        checkpoint_path = os.path.join(args.output_dir,
                              '%s_no_model.pt' % args.checkpoint_name)
        key_blacklist = ['model_state', 'optim_state', 'model_best_state',
                         'd_obj_state', 'd_obj_optim_state', 'd_obj_best_state',
                         'd_img_state', 'd_img_optim_state', 'd_img_best_state']
        small_checkpoint = {}
        for k, v in checkpoint.items():
          if k not in key_blacklist:
            small_checkpoint[k] = v
        torch.save(small_checkpoint, checkpoint_path)
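
# A rough sketch of the pair requested via get_gan_losses('matching_aware_gan')
# in the script above (an assumption following the matching-aware convention:
# the fake batch stacks generated pairs with mismatched real pairs, and the
# discriminator must reject both; the project's losses module is authoritative):
import torch
import torch.nn.functional as F

def matching_aware_d_loss(scores_real, scores_fake):
    # Real (image, context) pairs -> 1; fake and mismatched pairs -> 0.
    real = F.binary_cross_entropy_with_logits(scores_real,
                                              torch.ones_like(scores_real))
    fake = F.binary_cross_entropy_with_logits(scores_fake,
                                              torch.zeros_like(scores_fake))
    return real + fake

def matching_aware_g_loss(scores_fake):
    # The generator wants every fake pair classified as real and matching.
    return F.binary_cross_entropy_with_logits(scores_fake,
                                              torch.ones_like(scores_fake))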
Example no. 9
        if t % (args.n_critic + 1) != 0:
            if obj_discriminator is not None:
                imgs_fake = imgs_pred.detach()
                with timeit('d_obj forward for d', args.timing):
                    d_scores_fake_crop, d_obj_scores_fake_crop, fake_crops, d_rec_feature_fake_crop = \
                        obj_discriminator(imgs_fake, objs, boxes, obj_to_img)
                    d_scores_real_crop, d_obj_scores_real_crop, real_crops, d_rec_feature_real_crop = \
                        obj_discriminator(imgs, objs, boxes, obj_to_img)
                    if args.gan_loss_type == "wgan-gp":
                        d_obj_gp = gradient_penalty(
                            real_crops.detach(), fake_crops.detach(),
                            obj_discriminator.discriminator)

                # train the object discriminator on real vs. generated crops
                with timeit('d_obj loss', args.timing):
                    d_obj_losses = LossManager()
                    if args.d_obj_weight > 0:
                        d_obj_gan_loss = gan_d_loss(d_obj_scores_real_crop,
                                                    d_obj_scores_fake_crop)
                        d_obj_losses.add_loss(d_obj_gan_loss, 'd_obj_gan_loss')
                        if args.gan_loss_type == 'wgan-gp':
                            d_obj_losses.add_loss(d_obj_gp.mean(), 'd_obj_gp',
                                                  args.d_obj_gp_weight)
                    if args.ac_loss_weight > 0:
                        d_obj_losses.add_loss(
                            F.cross_entropy(d_obj_scores_real_crop, objs),
                            'd_ac_loss_real')
                        d_obj_losses.add_loss(
                            F.cross_entropy(d_obj_scores_fake_crop, objs),
                            'd_ac_loss_fake')
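
# A minimal sketch of the gradient_penalty helper called above (an assumption
# matching the standard WGAN-GP formulation; the project's own helper is
# authoritative):
import torch


def gradient_penalty(real, fake, netD):
    # Penalize the critic's gradient norm for deviating from 1 at random
    # interpolates between real and fake samples; the caller takes .mean().
    eps = torch.rand(real.size(0), *([1] * (real.dim() - 1)),
                     device=real.device)
    x_hat = (eps * real + (1 - eps) * fake).requires_grad_(True)
    scores = netD(x_hat)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=x_hat,
                                 create_graph=True)
    return (grads.view(grads.size(0), -1).norm(2, dim=1) - 1).pow(2)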