Example 1
        fasterRCNN = nn.DataParallel(fasterRCNN)

    iters_per_epoch = int(train_size / args.batch_size)

    if args.use_tfboard:
        from tensorboardX import SummaryWriter
        logger = SummaryWriter("logs")

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        # setting to train mode
        fasterRCNN.train()
        loss_temp = 0
        start = time.time()

        if epoch % (args.lr_decay_step + 1) == 0:
            adjust_learning_rate(optimizer, args.lr_decay_gamma)
            lr *= args.lr_decay_gamma

        data_iter = iter(dataloader)
        for step in range(iters_per_epoch):
            data = next(data_iter)
            # Legacy (pre-PyTorch 0.4) pattern: reuse preallocated buffers by
            # resizing them in place and copying the new batch into them.
            im_data.data.resize_(data[0].size()).copy_(data[0])
            im_info.data.resize_(data[1].size()).copy_(data[1])
            gt_boxes.data.resize_(data[2].size()).copy_(data[2])
            num_boxes.data.resize_(data[3].size()).copy_(data[3])

            fasterRCNN.zero_grad()
            rois, cls_prob, bbox_pred, \
                rpn_loss_cls, rpn_loss_box, \
                RCNN_loss_cls, RCNN_loss_bbox, \
                rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)
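
This fragment, like every example that follows, calls an `adjust_learning_rate` helper that is not part of the listing. A minimal sketch, assuming the usual Faster R-CNN training utility that rescales each parameter group's learning rate in place:

def adjust_learning_rate(optimizer, decay):
    # Multiply the learning rate of every param group by the decay factor.
    for param_group in optimizer.param_groups:
        param_group['lr'] = decay * param_group['lr']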
Example 2
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.target_only:
        source_train_dataset = TDETDataset(['voc07_trainval'],
                                           args.data_dir,
                                           args.prop_method,
                                           num_classes=20,
                                           prop_min_scale=args.prop_min_scale,
                                           prop_topk=args.num_prop)
    else:
        source_train_dataset = TDETDataset(
            ['coco60_train2014', 'coco60_val2014'],
            args.data_dir,
            args.prop_method,
            num_classes=60,
            prop_min_scale=args.prop_min_scale,
            prop_topk=args.num_prop)
    target_val_dataset = TDETDataset(['voc07_test'],
                                     args.data_dir,
                                     args.prop_method,
                                     num_classes=20,
                                     prop_min_scale=args.prop_min_scale,
                                     prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'DC_VGG16_DET':
        base_model = DC_VGG16_CLS(None, 20 if args.target_only else 80, 3, 4)
        checkpoint = torch.load(args.pretrained_base_path)
        base_model.load_state_dict(checkpoint['model'])
        del checkpoint
        model = DC_VGG16_DET(base_model, args.pooling_method)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    start = time.time()
    optimizer.zero_grad()
    source_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        # Reshuffle once per pass over the dataset; the None check covers the
        # first iteration, which would otherwise read source_rand_perm before
        # it is ever assigned.
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]

        source_batch = source_train_dataset.get_data(
            source_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        if args.target_only:
            source_gt_labels = source_batch['gt_labels']
        else:
            source_gt_labels = source_batch['gt_labels'] + 20
        source_pos_cls = [i for i in range(80) if i in source_gt_labels]

        source_loss = 0
        for cls in np.random.choice(source_pos_cls, 2):
            indices = np.where(source_gt_labels.numpy() == cls)[0]
            here_gt_boxes = source_gt_boxes[indices]
            here_proposals, here_labels, _, pos_cnt, neg_cnt = sample_proposals(
                here_gt_boxes, source_proposals, args.bs // 2, args.pos_ratio)
            # plt.imshow(source_batch['raw_img'])
            # draw_box(here_proposals[:pos_cnt] / source_batch['im_scale'], 'black')
            # draw_box(here_proposals[pos_cnt:] / source_batch['im_scale'], 'yellow')
            # plt.show()
            here_proposals = here_proposals.to(device)
            here_labels = here_labels.to(device)
            here_loss = model(source_im_data, cls, here_proposals, here_labels)
            source_loss = source_loss + here_loss

            source_pos_prop_sum += pos_cnt
            source_neg_prop_sum += neg_cnt

        source_loss = source_loss / 2

        source_loss_sum += source_loss.item()
        source_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        optimizer.zero_grad()

        if step % args.disp_interval == 0:
            end = time.time()
            source_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, source_loss_sum, source_pos_prop_sum, source_neg_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            source_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            validate(model, target_val_dataset, args, device)
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
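
`clip_gradient(model, 10.0)` is another external helper. A minimal sketch of what it presumably does, rescaling all gradients so their global L2 norm stays under the threshold; the built-in torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0) has the same effect:

def clip_gradient(model, clip_norm):
    # Compute the global L2 norm over all parameter gradients ...
    total_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_norm += p.grad.norm().item() ** 2
    total_norm = total_norm ** 0.5
    # ... and rescale every gradient in place if it exceeds the threshold.
    if total_norm > clip_norm:
        scale = clip_norm / (total_norm + 1e-6)
        for p in model.parameters():
            if p.grad is not None:
                p.grad.mul_(scale)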
Example 3
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'],
                                       args.data_dir,
                                       args.prop_method,
                                       num_classes=60,
                                       prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)
    target_train_dataset = TDETDataset(['voc07_trainval'],
                                       args.data_dir,
                                       args.prop_method,
                                       num_classes=20,
                                       prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'NEW_TDET':
        model = NEW_TDET(os.path.join(args.data_dir,
                                      'pretrained_model/vgg16_caffe.pth'),
                         20,
                         pooling_method=args.pooling_method,
                         share_level=args.share_level,
                         mil_topk=args.mil_topk)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkiter))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    target_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    target_prop_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        # The None checks guard the first iteration, which matters when
        # resuming: start_iter rarely lands on a shuffle boundary.
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        source_batch = source_train_dataset.get_data(
            source_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
        target_batch = target_train_dataset.get_data(
            target_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        source_proposals, source_labels, _, pos_cnt, neg_cnt = sample_proposals(
            source_gt_boxes, source_proposals, args.bs, args.pos_ratio)
        source_proposals = source_proposals.to(device)
        source_gt_boxes = source_gt_boxes.to(device)
        source_labels = source_labels.to(device)

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_proposals = target_batch['proposals'].to(device)
        target_image_level_label = target_batch['image_level_label'].to(device)

        optimizer.zero_grad()

        # source forward & backward
        _, source_loss = model.forward_det(source_im_data, source_proposals,
                                           source_labels)
        source_loss_sum += source_loss.item()
        source_loss = source_loss * (1 - args.alpha)
        source_loss.backward()

        # target forward & backward
        if args.cam_like:
            _, target_loss = model.forward_cls_camlike(
                target_im_data, target_proposals, target_image_level_label)
        else:
            _, target_loss = model.forward_cls(target_im_data,
                                               target_proposals,
                                               target_image_level_label)
        target_loss_sum += target_loss.item()
        target_loss = target_loss * args.alpha
        target_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        source_pos_prop_sum += pos_cnt
        source_neg_prop_sum += neg_cnt
        target_prop_sum += target_proposals.size(0)

        if step % args.disp_interval == 0:
            end = time.time()
            loss_sum = source_loss_sum * (
                1 - args.alpha) + target_loss_sum * args.alpha
            loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            target_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, src_loss: %.4f, tar_loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, tar_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, loss_sum, source_loss_sum, target_loss_sum, source_pos_prop_sum, source_neg_prop_sum, target_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            source_loss_sum = 0
            target_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            target_prop_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['share_level'] = args.share_level
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
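
`sample_proposals`, used by Examples 2 and 3, is also not listed. A hypothetical re-implementation under stated assumptions: a 0.5 IoU threshold, binary labels, and positives ordered first (Example 2's commented plotting code slices here_proposals[:pos_cnt] as the positives). The third return value, here the matched ground-truth indices, is discarded by every caller:

import torch
from torchvision.ops import box_iou

def sample_proposals(gt_boxes, proposals, bs, pos_ratio, pos_thresh=0.5):
    # IoU of every proposal (P, 4) against every ground-truth box (G, 4)
    max_iou, gt_idx = box_iou(proposals, gt_boxes).max(dim=1)
    pos_idx = torch.nonzero(max_iou >= pos_thresh).squeeze(1)
    neg_idx = torch.nonzero(max_iou < pos_thresh).squeeze(1)

    # Keep at most bs proposals, pos_ratio of them positive, positives first
    num_pos = min(int(bs * pos_ratio), pos_idx.numel())
    num_neg = min(bs - num_pos, neg_idx.numel())
    pos_idx = pos_idx[torch.randperm(pos_idx.numel())[:num_pos]]
    neg_idx = neg_idx[torch.randperm(neg_idx.numel())[:num_neg]]

    keep = torch.cat([pos_idx, neg_idx])
    labels = torch.cat([torch.ones(num_pos, dtype=torch.long),
                        torch.zeros(num_neg, dtype=torch.long)])
    return proposals[keep], labels, gt_idx[keep], num_pos, num_neg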
Example 4
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(3)
    torch.manual_seed(4)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(5)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    train_dataset = WSDDNDataset(dataset_names=['voc07_trainval'],
                                 data_dir=args.data_dir,
                                 prop_method=args.prop_method,
                                 num_classes=20,
                                 min_prop_scale=args.min_prop)

    lr = args.lr

    if args.net == 'WSDDN_VGG16':
        model = WSDDN_VGG16(
            os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth'),
            20)

    else:
        raise Exception('network is not defined')

    params = []
    for key, value in dict(model.named_parameters()).items():
        if value.requires_grad:
            if 'bias' in key:
                params += [{
                    'params': [value],
                    'lr': lr * 2,
                    'weight_decay': 0
                }]
            else:
                params += [{
                    'params': [value],
                    'lr': lr,
                    'weight_decay': 0.0005
                }]

    optimizer = torch.optim.SGD(params, momentum=0.9)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkepoch))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % (load_name))

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)

    for epoch in range(args.start_epoch, args.max_epochs + 1):
        model.train()
        loss_sum = 0
        reg_sum = 0
        iter_sum = 0
        num_prop = 0
        start = time.time()

        optimizer.zero_grad()
        rand_perm = np.random.permutation(len(train_dataset))
        for step in range(1, len(train_dataset) + 1):
            index = rand_perm[step - 1]
            apply_h_flip = np.random.rand() > 0.5
            target_im_size = np.random.choice([480, 576, 688, 864, 1200])
            im_data, gt_boxes, box_labels, proposals, prop_scores, image_level_label, im_scale, raw_img, im_id = \
                train_dataset.get_data(index, apply_h_flip, target_im_size)

            # plt.imshow(raw_img)
            # draw_box(proposals / im_scale)
            # draw_box(gt_boxes / im_scale, 'black')
            # plt.show()

            im_data = im_data.unsqueeze(0).to(device)
            rois = proposals.to(device)
            image_level_label = image_level_label.to(device)

            if args.use_prop_score:
                prop_scores = prop_scores.to(device)
            else:
                prop_scores = None
            scores, loss, reg = model(im_data, rois, prop_scores,
                                      image_level_label)
            reg = reg * args.alpha
            num_prop += proposals.size(0)
            loss_sum += loss.item()
            reg_sum += reg.item()
            loss = loss + reg
            if args.bavg:
                loss = loss / args.bs
            loss.backward()

            if step % args.bs == 0:
                optimizer.step()
                optimizer.zero_grad()
            iter_sum += 1

            if step % args.disp_interval == 0:
                end = time.time()

                print(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, reg: %.4f, num_prop: %.1f, lr: %.2e, time: %.1f"
                    %
                    (args.net, args.session, epoch, step, loss_sum / iter_sum,
                     reg_sum / iter_sum, num_prop / iter_sum, lr, end - start))
                log_file.write(
                    "[net %s][session %d][epoch %2d][iter %4d] loss: %.4f, reg: %.4f, num_prop: %.1f, lr: %.2e, time: %.1f\n"
                    %
                    (args.net, args.session, epoch, step, loss_sum / iter_sum,
                     reg_sum / iter_sum, num_prop / iter_sum, lr, end - start))
                loss_sum = 0
                reg_sum = 0
                num_prop = 0
                iter_sum = 0
                start = time.time()

        log_file.flush()
        if epoch == 10:
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if epoch % args.save_interval == 0:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  epoch))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['epoch'] = epoch + 1
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
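
`save_checkpoint` appears in every example but is never defined. It is presumably a thin wrapper over torch.save, since the --resume branches restore the same dict with torch.load:

import torch

def save_checkpoint(state, filename):
    # Persist the checkpoint dict; torch.load(filename) in the --resume
    # branches above restores exactly this structure.
    torch.save(state, filename)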
Example 5
def main():
  ## CALL ARGUMENTS
  args = arguments()
  print("Called Arguments")
  
  ## LOG PATH
  date = datetime.datetime.now()
  day = date.strftime('%m%d_%H%M')
  log_path = "./logs/{}".format(day)
  if not os.path.exists(log_path):
    os.mkdir(log_path)

  ## INITIALIZE TENSORBOARD
  if args.tfboard:
    from utils.logger import Logger
    logger = Logger(log_path)

  ## CONFIG SAVE AS TEXT
  configs = "Dataset: {}\nLSTM Size: {}\nNumber of Proposals: {}\nStart Learning Rate: {}\nLearning Rate Decay: {}\nOptimizer: {}\nScore Threshold: {}\nEncoding: {}".format(args.dataset, args.lstm_size, args.num_prop_after, args.learning_rate, args.learning_rate_decay, args.optimizer, args.score_thresh, args.encoding)
  txt_file = "{}/configs.txt".format(log_path)
  f = open(txt_file, 'w')
  f.write(configs)
  f.close()

  ## DATASET CONFIGURATION
  if args.dataset == 'cvpr19':
    args.anchor_scale = [[400, 300, 200, 100]]

  ## CUDA CHECK
  if not torch.cuda.is_available():
    print("WARNING: CUDA is not available, but this script assumes a GPU")

  ## DATALOADER (ITERATOR)
  data_type = ['train', 'val', 'test']
  #data_type = ['val']
  loader = {}
  for split in data_type:
    EEGDetectionData = loaddata.EEGDetectionDataset(args, split) # each sample: (seq_len, num_ch)
    loader[split] = DataLoader(dataset=EEGDetectionData, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers) # dataset is already shuffled

  seq_len, num_ch = EEGDetectionData[0][0].shape[0], EEGDetectionData[0][0].shape[1]
  args.seq_len, args.num_ch = seq_len, num_ch

  ## ENCODING OR NOT
  if args.encoding:
    args.encoding_size = int((args.num_ch) / (args.encoding_scale))

  ## CALL MODEL
  model = rlstm(args)
  model.create_architecture()

  ## OPTIMIZER
  #params = []
  #for key, value in dict(model.named_parameters()).items():
  #  if value.requires_grad:
  #    params += [{'params':[value],'lr':lr, 'weight_decay': args.weight_decay}]
  lr = args.learning_rate
  optimizer = getattr(torch.optim, args.optimizer)(model.parameters(), lr = args.learning_rate)

  model.cuda() # this script assumes CUDA throughout (inputs are moved with .cuda() below)
  ## RESUME

  if args.resume:
    checkpoint = torch.load('./logs/0524_2245/save_model/thecho7_25.pth')
    args.start_epoch = checkpoint['epoch']
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    print("Resume the training")

  ## TRAINING
  for epoch in range(args.start_epoch, args.epochs):
    
    acc = {"train": 0, "val": 0, "test": 0}
    counts = {"train": 0, "val": 0, "test": 0}

    # TRAIN MODE
    model.train()
    loss_check = 0
    start = time.time()

    # Learning Rate Adjustment
    if epoch % (args.checkpoint_interval) == 0:
      adjust_learning_rate(optimizer, args.learning_rate_decay)
      lr *= args.learning_rate_decay

    # Session mode - Train or Test
    for split in data_type:
      if split == 'train':
        args.training = True
        model.train()
      else:
        args.training = False
        model.eval()

      for i, data in enumerate(loader[split], 0):
        print(split)
        # READ BATCH DATA (IN OUR CASE, BATCH SIZE IS 1)
        inputs, labels = data
        inputs, proposals, labels = proposal_gen(inputs, labels, args)
        inputs = inputs.cuda()
        labels = labels.cuda()
        # Variable/volatile are gone since PyTorch 0.4; gate autograd per split instead
        torch.set_grad_enabled(split == 'train')
        
        # FORWARD
        cls_loss, bbox_loss, acc = model(inputs, labels, proposals, split, acc)
        '''
        if split == 'train':
          cls_loss = F.cross_entropy(cls_feat, labels.long()) # F.cross_entropy converts indices automatically
        else:
          cls_loss = Variable(torch.zeros(1).cuda()) # Garbage Value

        # Penalized Loss (Division Method)
        loss_div = 1
        for j in range(args.num_prop_after):
          if int(cls_feat.data.max(1)[1][j]) == int(labels[j]):
            loss_div += 1
            acc[split] += 1
        cls_loss = cls_loss.div(loss_div)

        # Result Print
        _, cls_idx = cls_feat.data.max(1)
        result_print = []
        for j in range(args.num_prop_after):
          result_print.append(int(cls_idx[j]))
        print("  Result labels: {}".format(result_print))
        '''
        # BACKWARD
        loss = cls_loss.mean() + bbox_loss.mean()
        loss_check += loss.item()
        if split == 'train':
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
          print("CLASS LOSS: {} BBOX LOSS: {}".format(loss.item(), bbox_loss.mean().item()))

        counts[split] += args.num_prop_after

        # LOSS DISPLAY
        if i % args.checkpoint_interval == 1:
          end = time.time()
          if i > 1:
            loss_check /= args.checkpoint_interval
            print("[epoch: {} - loss_check: {}]".format(epoch, loss_check))
            print("[Iter Accuracy: {}]".format(acc["train"] / counts["train"]))
          loss_cls = cls_loss.mean().item()

          if args.tfboard:
            info = {'loss':loss_check, 'loss_cls':loss_cls}
            for tag, value in info.items():
              logger.scalar_summary(tag, value, i)

          loss_check = 0
          start = time.time()

    # Print info at the end of the epoch
    print("Epoch {}: TrA={:.4f}, VA={:.4f}, TeA={:.4f}".format(epoch, acc["train"]/counts["train"], acc["val"]/counts["val"], acc["test"]/counts["test"]))

    ## SAVE MODEL (TBI)
    model_path = "{}/{}".format(log_path, "save_model")
    if not os.path.exists(model_path):
      os.mkdir(model_path)
    save_name = os.path.join('{}'.format(model_path), 'thecho7_{}.pth'.format(epoch))
    save_checkpoint({
      'epoch': epoch,
      'model': model.state_dict(),
      'optimizer': optimizer.state_dict(),
    }, save_name)
    print('Saving Model: {}......'.format(save_name))
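
utils.logger.Logger is external to this listing; only its scalar_summary(tag, value, step) method is used. A minimal stand-in built on torch.utils.tensorboard (an assumption; the original was likely a TensorFlow-based logger, as was common at the time):

from torch.utils.tensorboard import SummaryWriter

class Logger:
    def __init__(self, log_dir):
        self.writer = SummaryWriter(log_dir)

    def scalar_summary(self, tag, value, step):
        # Mirror the scalar_summary(tag, value, step) interface used above
        self.writer.add_scalar(tag, value, step)
        self.writer.flush()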
Example 6
def train():
    args = parse_args()
    print('Called with args:')
    print(args)
    assert args.bs % 2 == 0

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(device)
    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    target_only = args.target_only
    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'],
                                       args.data_dir,
                                       'eb',
                                       num_classes=60)
    target_train_dataset = TDETDataset(['voc07_trainval'],
                                       args.data_dir,
                                       'eb',
                                       num_classes=20)

    lr = args.lr

    if args.net == 'CAM_DET':
        model = CamDet(
            os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth')
            if not args.resume else None, 20 if target_only else 80,
            args.hidden_dim)
    else:
        raise Exception('network is not defined')

    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(
            output_dir, '{}_{}_{}.pth'.format(args.net, args.checksession,
                                              args.checkiter))
        print("loading checkpoint %s" % (load_name))
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        print("loaded checkpoint %s" % (load_name))
        del checkpoint

    log_file_name = os.path.join(
        output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    if args.resume:
        log_file = open(log_file_name, 'a')
    else:
        log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()
    source_loss_sum = 0
    target_loss_sum = 0
    total_loss_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))

        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        optimizer.zero_grad()
        if not target_only:
            source_batch = source_train_dataset.get_data(
                source_index,
                h_flip=np.random.rand() > 0.5,
                target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

            source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
            source_gt_labels = source_batch['gt_labels'] + 20
            source_pos_cls = [i for i in range(80) if i in source_gt_labels]
            source_pos_cls = torch.tensor(np.random.choice(
                source_pos_cls,
                min(args.bs, len(source_pos_cls)),
                replace=False),
                                          dtype=torch.long,
                                          device=device)

            source_loss, _, _ = model(source_im_data, source_pos_cls)
            source_loss_sum += source_loss.item()

        target_batch = target_train_dataset.get_data(
            target_index,
            h_flip=np.random.rand() > 0.5,
            target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_gt_labels = target_batch['gt_labels']
        target_pos_cls = [i for i in range(80) if i in target_gt_labels]
        target_pos_cls = torch.tensor(np.random.choice(
            target_pos_cls, min(args.bs, len(target_pos_cls)), replace=False),
                                      dtype=torch.long,
                                      device=device)

        target_loss, _, _, _ = model(target_im_data, target_pos_cls)
        target_loss_sum += target_loss.item()
        if args.target_only:
            total_loss = target_loss
        else:
            total_loss = (source_loss + target_loss) * 0.5
        total_loss.backward()
        total_loss_sum += total_loss.item()
        clip_gradient(model, 10.0)
        optimizer.step()

        if step % args.disp_interval == 0:
            end = time.time()
            total_loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.8f, src_loss: %.8f, tar_loss: %.8f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, total_loss_sum, source_loss_sum, target_loss_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()
            total_loss_sum = 0
            source_loss_sum = 0
            target_loss_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(
                output_dir, '{}_{}_{}.pth'.format(args.net, args.session,
                                                  step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()

            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
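
Finally, every example opens with parse_args(), which is never shown. A hypothetical sketch covering only the flags the listings above actually read; the names match their usage, the defaults are guesses:

import argparse

def parse_args():
    parser = argparse.ArgumentParser(description='Train a detection model')
    parser.add_argument('--net', default='NEW_TDET')
    parser.add_argument('--session', type=int, default=1)
    parser.add_argument('--seed', type=int, default=1)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--bs', type=int, default=128)
    parser.add_argument('--pos_ratio', type=float, default=0.25)
    parser.add_argument('--start_iter', type=int, default=1)
    parser.add_argument('--max_iter', type=int, default=70000)
    parser.add_argument('--disp_interval', type=int, default=100)
    parser.add_argument('--save_interval', type=int, default=5000)
    parser.add_argument('--save_dir', default='./output')
    parser.add_argument('--data_dir', default='./data')
    parser.add_argument('--resume', action='store_true')
    parser.add_argument('--target_only', action='store_true')
    parser.add_argument('--checksession', type=int, default=1)
    parser.add_argument('--checkiter', type=int, default=0)
    return parser.parse_args()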