# Fragment of a Faster R-CNN training step (body of the iteration loop).
fasterRCNN.zero_grad()
rois, cls_prob, bbox_pred, \
rpn_loss_cls, rpn_loss_box, \
RCNN_loss_cls, RCNN_loss_bbox, \
rois_label = fasterRCNN(im_data, im_info, gt_boxes, num_boxes)

# Total loss is the sum of the four RPN / RCNN classification and box losses.
loss = rpn_loss_cls.mean() + rpn_loss_box.mean() \
     + RCNN_loss_cls.mean() + RCNN_loss_bbox.mean()
loss_temp += loss.item()

# backward
optimizer.zero_grad()
loss.backward()
if args.net == "vgg16":
    clip_gradient(fasterRCNN, 10.)
optimizer.step()

if step % args.disp_interval == 0:
    end = time.time()
    if step > 0:
        loss_temp /= (args.disp_interval + 1)

    if args.mGPUs:
        # DataParallel returns per-GPU losses, so reduce with .mean().
        loss_rpn_cls = rpn_loss_cls.mean().item()
        loss_rpn_box = rpn_loss_box.mean().item()
        loss_rcnn_cls = RCNN_loss_cls.mean().item()
        loss_rcnn_box = RCNN_loss_bbox.mean().item()
        fg_cnt = torch.sum(rois_label.data.ne(0))
        bg_cnt = rois_label.data.numel() - fg_cnt
    else:
        # Single-GPU branch: the losses are already scalars.
        loss_rpn_cls = rpn_loss_cls.item()
        loss_rpn_box = rpn_loss_box.item()
        loss_rcnn_cls = RCNN_loss_cls.item()
        loss_rcnn_box = RCNN_loss_bbox.item()
        fg_cnt = torch.sum(rois_label.data.ne(0))
        bg_cnt = rois_label.data.numel() - fg_cnt
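# `clip_gradient` is called above and in the scripts below but is not defined
# in this section. A minimal sketch, assuming the norm-rescaling helper used
# by faster-rcnn.pytorch-style code (an assumption, not confirmed here):
import torch


def clip_gradient(model, clip_norm):
    """Rescale all gradients so their total L2 norm does not exceed clip_norm."""
    totalnorm = 0
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            modulenorm = p.grad.norm()
            totalnorm += modulenorm ** 2
    totalnorm = torch.sqrt(totalnorm).item()
    # Scaling factor is 1 when the norm is already within the budget.
    norm = clip_norm / max(totalnorm, clip_norm)
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            p.grad.mul_(norm)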
# Training script: NEW_TDET with a COCO-60 detection source and a VOC07
# image-level (MIL) target, mixed by args.alpha.
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # COCO-60 source (classes disjoint from VOC) and VOC07 target datasets.
    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'], args.data_dir, args.prop_method,
                                       num_classes=60, prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)
    target_train_dataset = TDETDataset(['voc07_trainval'], args.data_dir, args.prop_method,
                                       num_classes=20, prop_min_scale=args.prop_min_scale,
                                       prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'NEW_TDET':
        model = NEW_TDET(os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth'), 20,
                         pooling_method=args.pooling_method, share_level=args.share_level,
                         mil_topk=args.mil_topk)
    else:
        raise Exception('network is not defined')
    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(output_dir,
                                 '{}_{}_{}.pth'.format(args.net, args.checksession, args.checkiter))
        print("loading checkpoint %s" % load_name)
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr = optimizer.param_groups[0]['lr']
        print("loaded checkpoint %s" % load_name)

    log_file_name = os.path.join(output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'a' if args.resume else 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()

    source_loss_sum = 0
    target_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    target_prop_sum = 0
    start = time.time()
    # Initialize to None so that resuming from an arbitrary iteration still
    # creates the permutations on the first pass through the loop (otherwise
    # they could be referenced before assignment).
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))
        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        # Random horizontal flip and multi-scale resizing for both domains.
        source_batch = source_train_dataset.get_data(source_index,
                                                     h_flip=np.random.rand() > 0.5,
                                                     target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
        target_batch = target_train_dataset.get_data(target_index,
                                                     h_flip=np.random.rand() > 0.5,
                                                     target_im_size=np.random.choice([480, 576, 688, 864, 1200]))

        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        source_proposals, source_labels, _, pos_cnt, neg_cnt = sample_proposals(source_gt_boxes, source_proposals,
                                                                                args.bs, args.pos_ratio)
        source_proposals = source_proposals.to(device)
        source_gt_boxes = source_gt_boxes.to(device)
        source_labels = source_labels.to(device)

        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_proposals = target_batch['proposals'].to(device)
        target_image_level_label = target_batch['image_level_label'].to(device)

        optimizer.zero_grad()

        # source forward & backward (fully supervised detection loss)
        _, source_loss = model.forward_det(source_im_data, source_proposals, source_labels)
        source_loss_sum += source_loss.item()
        source_loss = source_loss * (1 - args.alpha)
        source_loss.backward()

        # target forward & backward (image-level classification loss)
        if args.cam_like:
            _, target_loss = model.forward_cls_camlike(target_im_data, target_proposals,
                                                       target_image_level_label)
        else:
            _, target_loss = model.forward_cls(target_im_data, target_proposals,
                                               target_image_level_label)
        target_loss_sum += target_loss.item()
        target_loss = target_loss * args.alpha
        target_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()

        source_pos_prop_sum += pos_cnt
        source_neg_prop_sum += neg_cnt
        target_prop_sum += target_proposals.size(0)

        if step % args.disp_interval == 0:
            end = time.time()
            loss_sum = source_loss_sum * (1 - args.alpha) + target_loss_sum * args.alpha
            loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            target_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, src_loss: %.4f, tar_loss: %.4f, " \
                          "pos_prop: %.1f, neg_prop: %.1f, tar_prop: %.1f, lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, loss_sum, source_loss_sum, target_loss_sum,
                           source_pos_prop_sum, source_neg_prop_sum, target_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()

            source_loss_sum = 0
            target_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            target_prop_sum = 0
            start = time.time()

        # Decay the learning rate by 10x at 4/7 and 6/7 of training.
        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(output_dir, '{}_{}_{}.pth'.format(args.net, args.session, step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['share_level'] = args.share_level
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
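# `sample_proposals` is a project-local helper whose definition is not in this
# section. A minimal sketch of what it plausibly does, assuming the usual
# IoU-based foreground/background split; the threshold, the third return
# value, and the binary label convention are guesses, not confirmed by this
# code:
import torch


def sample_proposals_sketch(gt_boxes, proposals, batch_size, pos_ratio, fg_thresh=0.5):
    """Hypothetical stand-in: sample up to batch_size proposals, aiming for
    pos_ratio foreground. gt_boxes is (M, 4) and proposals is (N, 4), both in
    (x1, y1, x2, y2) form. Returns (proposals, labels, overlaps, pos_cnt, neg_cnt)."""
    # Pairwise IoU between proposals and ground-truth boxes, shape (N, M).
    area_p = (proposals[:, 2] - proposals[:, 0]) * (proposals[:, 3] - proposals[:, 1])
    area_g = (gt_boxes[:, 2] - gt_boxes[:, 0]) * (gt_boxes[:, 3] - gt_boxes[:, 1])
    lt = torch.max(proposals[:, None, :2], gt_boxes[None, :, :2])
    rb = torch.min(proposals[:, None, 2:], gt_boxes[None, :, 2:])
    wh = (rb - lt).clamp(min=0)
    inter = wh[..., 0] * wh[..., 1]
    iou = inter / (area_p[:, None] + area_g[None, :] - inter)

    # Proposals overlapping any GT box above the threshold are foreground.
    max_overlaps, _ = iou.max(dim=1)
    fg_idx = torch.nonzero(max_overlaps >= fg_thresh).view(-1)
    bg_idx = torch.nonzero(max_overlaps < fg_thresh).view(-1)

    # Keep at most pos_ratio * batch_size foreground, fill the rest with background.
    num_fg = min(int(batch_size * pos_ratio), fg_idx.numel())
    num_bg = min(batch_size - num_fg, bg_idx.numel())
    fg_idx = fg_idx[torch.randperm(fg_idx.numel())[:num_fg]]
    bg_idx = bg_idx[torch.randperm(bg_idx.numel())[:num_bg]]
    keep = torch.cat([fg_idx, bg_idx])

    # Foreground first, so callers can slice positives as proposals[:pos_cnt].
    labels = torch.zeros(keep.numel(), dtype=torch.long)
    labels[:num_fg] = 1
    return proposals[keep], labels, max_overlaps[keep], num_fg, num_bg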
# Training script: DC_VGG16_DET fine-tuned per class on top of a pretrained
# DC_VGG16_CLS base, with periodic validation on VOC07 test.
def train():
    args = parse_args()
    print('Called with args:')
    print(args)

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    if args.target_only:
        source_train_dataset = TDETDataset(['voc07_trainval'], args.data_dir, args.prop_method,
                                           num_classes=20, prop_min_scale=args.prop_min_scale,
                                           prop_topk=args.num_prop)
    else:
        source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'], args.data_dir,
                                           args.prop_method, num_classes=60,
                                           prop_min_scale=args.prop_min_scale, prop_topk=args.num_prop)
    target_val_dataset = TDETDataset(['voc07_test'], args.data_dir, args.prop_method,
                                     num_classes=20, prop_min_scale=args.prop_min_scale,
                                     prop_topk=args.num_prop)

    lr = args.lr

    if args.net == 'DC_VGG16_DET':
        base_model = DC_VGG16_CLS(None, 20 if args.target_only else 80, 3, 4)
        checkpoint = torch.load(args.pretrained_base_path)
        base_model.load_state_dict(checkpoint['model'])
        del checkpoint
        model = DC_VGG16_DET(base_model, args.pooling_method)
    else:
        raise Exception('network is not defined')
    optimizer = model.get_optimizer(args.lr)

    log_file_name = os.path.join(output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()

    source_loss_sum = 0
    source_pos_prop_sum = 0
    source_neg_prop_sum = 0
    start = time.time()
    # Initialize to None so the permutation is created on the first pass even
    # when args.start_iter does not align with the dataset length.
    source_rand_perm = None
    optimizer.zero_grad()
    for step in range(args.start_iter, args.max_iter + 1):
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        source_index = source_rand_perm[step % len(source_train_dataset)]

        source_batch = source_train_dataset.get_data(source_index,
                                                     h_flip=np.random.rand() > 0.5,
                                                     target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
        source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
        source_proposals = source_batch['proposals']
        source_gt_boxes = source_batch['gt_boxes']
        # Shift COCO-60 labels past the 20 VOC classes when training on both.
        if args.target_only:
            source_gt_labels = source_batch['gt_labels']
        else:
            source_gt_labels = source_batch['gt_labels'] + 20
        source_pos_cls = [i for i in range(80) if i in source_gt_labels]

        # Train on two randomly chosen positive classes per image, each with
        # its own proposal sample of size args.bs // 2.
        source_loss = 0
        for cls in np.random.choice(source_pos_cls, 2):
            indices = np.where(source_gt_labels.numpy() == cls)[0]
            here_gt_boxes = source_gt_boxes[indices]
            here_proposals, here_labels, _, pos_cnt, neg_cnt = sample_proposals(here_gt_boxes, source_proposals,
                                                                                args.bs // 2, args.pos_ratio)
            # plt.imshow(source_batch['raw_img'])
            # draw_box(here_proposals[:pos_cnt] / source_batch['im_scale'], 'black')
            # draw_box(here_proposals[pos_cnt:] / source_batch['im_scale'], 'yellow')
            # plt.show()
            here_proposals = here_proposals.to(device)
            here_labels = here_labels.to(device)

            here_loss = model(source_im_data, cls, here_proposals, here_labels)
            source_loss = source_loss + here_loss
            source_pos_prop_sum += pos_cnt
            source_neg_prop_sum += neg_cnt

        source_loss = source_loss / 2
        source_loss_sum += source_loss.item()
        source_loss.backward()

        clip_gradient(model, 10.0)
        optimizer.step()
        optimizer.zero_grad()

        if step % args.disp_interval == 0:
            end = time.time()
            source_loss_sum /= args.disp_interval
            source_pos_prop_sum /= args.disp_interval
            source_neg_prop_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.4f, pos_prop: %.1f, neg_prop: %.1f, " \
                          "lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, source_loss_sum,
                           source_pos_prop_sum, source_neg_prop_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()

            source_loss_sum = 0
            source_pos_prop_sum = 0
            source_neg_prop_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            validate(model, target_val_dataset, args, device)
            save_name = os.path.join(output_dir, '{}_{}_{}.pth'.format(args.net, args.session, step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['pooling_method'] = args.pooling_method
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()
            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
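# `adjust_learning_rate` and `save_checkpoint` are used by all of these
# scripts but defined elsewhere in the project. Minimal sketches, assuming
# the conventional faster-rcnn.pytorch-style helpers (an assumption, not
# taken from this file):
import torch


def adjust_learning_rate(optimizer, decay=0.1):
    """Multiply every parameter group's learning rate by `decay`."""
    for param_group in optimizer.param_groups:
        param_group['lr'] = decay * param_group['lr']


def save_checkpoint(state, filename):
    """Serialize a checkpoint dict to disk."""
    torch.save(state, filename)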
# Training script: CAM_DET trained on class-activation-style losses, mixing
# COCO-60 source and VOC07 target classes unless --target_only is set.
def train():
    args = parse_args()
    print('Called with args:')
    print(args)
    assert args.bs % 2 == 0

    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(args.seed)
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(device)

    output_dir = args.save_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    target_only = args.target_only
    source_train_dataset = TDETDataset(['coco60_train2014', 'coco60_val2014'], args.data_dir, 'eb',
                                       num_classes=60)
    target_train_dataset = TDETDataset(['voc07_trainval'], args.data_dir, 'eb', num_classes=20)

    lr = args.lr

    if args.net == 'CAM_DET':
        model = CamDet(os.path.join(args.data_dir, 'pretrained_model/vgg16_caffe.pth')
                       if not args.resume else None,
                       20 if target_only else 80, args.hidden_dim)
    else:
        raise Exception('network is not defined')
    optimizer = model.get_optimizer(args.lr)

    if args.resume:
        load_name = os.path.join(output_dir,
                                 '{}_{}_{}.pth'.format(args.net, args.checksession, args.checkiter))
        print("loading checkpoint %s" % load_name)
        checkpoint = torch.load(load_name)
        assert args.net == checkpoint['net']
        args.start_iter = checkpoint['iterations'] + 1
        model.load_state_dict(checkpoint['model'])
        print("loaded checkpoint %s" % load_name)
        del checkpoint

    log_file_name = os.path.join(output_dir, 'log_{}_{}.txt'.format(args.net, args.session))
    log_file = open(log_file_name, 'a' if args.resume else 'w')
    log_file.write(str(args))
    log_file.write('\n')

    model.to(device)
    model.train()

    source_loss_sum = 0
    target_loss_sum = 0
    total_loss_sum = 0
    start = time.time()
    source_rand_perm = None
    target_rand_perm = None
    for step in range(args.start_iter, args.max_iter + 1):
        if source_rand_perm is None or step % len(source_train_dataset) == 1:
            source_rand_perm = np.random.permutation(len(source_train_dataset))
        if target_rand_perm is None or step % len(target_train_dataset) == 1:
            target_rand_perm = np.random.permutation(len(target_train_dataset))
        source_index = source_rand_perm[step % len(source_train_dataset)]
        target_index = target_rand_perm[step % len(target_train_dataset)]

        optimizer.zero_grad()

        if not target_only:
            source_batch = source_train_dataset.get_data(source_index,
                                                         h_flip=np.random.rand() > 0.5,
                                                         target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
            source_im_data = source_batch['im_data'].unsqueeze(0).to(device)
            # Shift COCO-60 labels past the 20 VOC classes.
            source_gt_labels = source_batch['gt_labels'] + 20
            source_pos_cls = [i for i in range(80) if i in source_gt_labels]
            source_pos_cls = torch.tensor(np.random.choice(source_pos_cls,
                                                           min(args.bs, len(source_pos_cls)),
                                                           replace=False),
                                          dtype=torch.long, device=device)

            source_loss, _, _ = model(source_im_data, source_pos_cls)
            source_loss_sum += source_loss.item()

        target_batch = target_train_dataset.get_data(target_index,
                                                     h_flip=np.random.rand() > 0.5,
                                                     target_im_size=np.random.choice([480, 576, 688, 864, 1200]))
        target_im_data = target_batch['im_data'].unsqueeze(0).to(device)
        target_gt_labels = target_batch['gt_labels']
        target_pos_cls = [i for i in range(80) if i in target_gt_labels]
        target_pos_cls = torch.tensor(np.random.choice(target_pos_cls,
                                                       min(args.bs, len(target_pos_cls)),
                                                       replace=False),
                                      dtype=torch.long, device=device)

        target_loss, _, _, _ = model(target_im_data, target_pos_cls)
        target_loss_sum += target_loss.item()

        if args.target_only:
            total_loss = target_loss
        else:
            total_loss = (source_loss + target_loss) * 0.5
        total_loss.backward()
        total_loss_sum += total_loss.item()

        clip_gradient(model, 10.0)
        optimizer.step()

        if step % args.disp_interval == 0:
            end = time.time()
            total_loss_sum /= args.disp_interval
            source_loss_sum /= args.disp_interval
            target_loss_sum /= args.disp_interval
            log_message = "[%s][session %d][iter %4d] loss: %.8f, src_loss: %.8f, tar_loss: %.8f, " \
                          "lr: %.2e, time: %.1f" % \
                          (args.net, args.session, step, total_loss_sum, source_loss_sum,
                           target_loss_sum, lr, end - start)
            print(log_message)
            log_file.write(log_message + '\n')
            log_file.flush()

            total_loss_sum = 0
            source_loss_sum = 0
            target_loss_sum = 0
            start = time.time()

        if step in (args.max_iter * 4 // 7, args.max_iter * 6 // 7):
            adjust_learning_rate(optimizer, 0.1)
            lr *= 0.1

        if step % args.save_interval == 0 or step == args.max_iter:
            save_name = os.path.join(output_dir, '{}_{}_{}.pth'.format(args.net, args.session, step))
            checkpoint = dict()
            checkpoint['net'] = args.net
            checkpoint['session'] = args.session
            checkpoint['iterations'] = step
            checkpoint['model'] = model.state_dict()
            save_checkpoint(checkpoint, save_name)
            print('save model: {}'.format(save_name))

    log_file.close()
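# Every model above exposes `get_optimizer(lr)`, which is not shown in this
# section. A plausible sketch assuming the common Caffe-style convention (2x
# learning rate and no weight decay for biases) used by
# faster-rcnn.pytorch-derived code; the real method may differ:
import torch


def get_optimizer_sketch(model, lr, momentum=0.9, weight_decay=0.0005):
    """Build an SGD optimizer with per-parameter-group learning rates."""
    params = []
    for key, value in dict(model.named_parameters()).items():
        if not value.requires_grad:
            continue
        if 'bias' in key:
            # Biases: doubled learning rate, no weight decay.
            params.append({'params': [value], 'lr': lr * 2, 'weight_decay': 0})
        else:
            params.append({'params': [value], 'lr': lr, 'weight_decay': weight_decay})
    # Every group specifies its own lr, so no global default is needed.
    return torch.optim.SGD(params, momentum=momentum)

# Hypothetical usage, mirroring the calls in the scripts above:
#   optimizer = get_optimizer_sketch(model, args.lr)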