1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
def box_iou(b1, b2):
    '''Compute the pairwise IoU between two sets of xywh boxes.

    Parameters
    ----------
    b1: tensor, shape=(i1,...,iN, 4), xywh
    b2: tensor, shape=(j, 4), xywh

    Returns
    -------
    iou: tensor, shape=(i1,...,iN, j)

    '''
    def _corners(boxes):
        # Split xywh boxes into (min corner, max corner, wh).
        centre = boxes[..., :2]
        size = boxes[..., 2:4]
        half_size = size/2.
        return centre - half_size, centre + half_size, size

    # Add singleton axes so every b1 box broadcasts against every b2 box.
    first_mins, first_maxes, first_wh = _corners(K.expand_dims(b1, -2))
    second_mins, second_maxes, second_wh = _corners(K.expand_dims(b2, 0))

    # Overlap rectangle; a negative extent means no overlap and is clamped to 0.
    overlap_mins = K.maximum(first_mins, second_mins)
    overlap_maxes = K.minimum(first_maxes, second_maxes)
    overlap_wh = K.maximum(overlap_maxes - overlap_mins, 0.)
    overlap_area = overlap_wh[..., 0] * overlap_wh[..., 1]

    first_area = first_wh[..., 0] * first_wh[..., 1]
    second_area = second_wh[..., 0] * second_wh[..., 1]

    # Union = sum of both areas minus the part counted twice.
    return overlap_area / (first_area + second_area - overlap_area)
def yolo_loss(args, anchors, num_classes, ignore_thresh=.5, print_loss=False):
    '''Return yolo_loss tensor

    Parameters
    ----------
    args: list of tensor, the yolo_outputs followed by the y_true tensors.
        The first len(anchors)//3 entries are used as yolo_outputs
        (the output of yolo_body or tiny_yolo_body) and the remaining
        entries as y_true (the output of preprocess_true_boxes).
    anchors: array, shape=(N, 2), wh
    num_classes: integer
    ignore_thresh: float, the iou threshold whether to ignore object confidence loss
    print_loss: bool, when true the running loss is wrapped in tf.Print so the
        individual loss terms are printed for every output layer

    Returns
    -------
    loss: tensor, sum over all output layers of the xy, wh, confidence and
        class losses, each reduced with K.sum and divided by the batch size

    '''
    num_layers = len(anchors)//3
    yolo_outputs = args[:num_layers]
    y_true = args[num_layers:]
    # Indices into `anchors` used by each output layer.
    anchor_mask = [[6,7,8], [3,4,5], [0,1,2]] if num_layers==3 else [[3,4,5], [1,2,3]]
    # The first output's spatial dims times 32 are taken as the input size.
    input_shape = K.cast(K.shape(yolo_outputs[0])[1:3] * 32, K.dtype(y_true[0]))
    grid_shapes = [K.cast(K.shape(yolo_outputs[l])[1:3], K.dtype(y_true[0])) for l in range(num_layers)]
    loss = 0
    m = K.shape(yolo_outputs[0])[0]  # batch size, as a tensor
    mf = K.cast(m, K.dtype(yolo_outputs[0]))

    for l in range(num_layers):
        object_mask = y_true[l][..., 4:5]
        true_class_probs = y_true[l][..., 5:]

        grid, raw_pred, pred_xy, pred_wh = yolo_head(yolo_outputs[l],
             anchors[anchor_mask[l]], num_classes, input_shape, calc_loss=True)
        pred_box = K.concatenate([pred_xy, pred_wh])

        # Convert the true boxes into the raw (pre-activation) target space.
        raw_true_xy = y_true[l][..., :2]*grid_shapes[l][::-1] - grid
        raw_true_wh = K.log(y_true[l][..., 2:4] / anchors[anchor_mask[l]] * input_shape[::-1])
        # Zero the wh target wherever object_mask is off; presumably this
        # guards against log(0) = -inf on empty cells.
        raw_true_wh = K.switch(object_mask, raw_true_wh, K.zeros_like(raw_true_wh))
        # Scale is 2 - w*h, so boxes with a smaller w*h product weigh more.
        box_loss_scale = 2 - y_true[l][...,2:3]*y_true[l][...,3:4]

        # Build the ignore mask one batch item at a time.
        ignore_mask = tf.TensorArray(K.dtype(y_true[0]), size=1, dynamic_size=True)
        object_mask_bool = K.cast(object_mask, 'bool')
        def loop_body(b, ignore_mask):
            # Keep only the true boxes of batch item b where object_mask is set.
            true_box = tf.boolean_mask(y_true[l][b,...,0:4], object_mask_bool[b,...,0])
            iou = box_iou(pred_box[b], true_box)
            best_iou = K.max(iou, axis=-1)
            # 1 where the best IoU with any true box is below ignore_thresh.
            ignore_mask = ignore_mask.write(b, K.cast(best_iou<ignore_thresh, K.dtype(true_box)))
            return b+1, ignore_mask
        _, ignore_mask = K.control_flow_ops.while_loop(lambda b,*_: b<m, loop_body, [0, ignore_mask])
        ignore_mask = ignore_mask.stack()
        ignore_mask = K.expand_dims(ignore_mask, -1)

        # K.binary_crossentropy is helpful to avoid exp overflow.
        xy_loss = object_mask * box_loss_scale * K.binary_crossentropy(raw_true_xy, raw_pred[...,0:2], from_logits=True)
        wh_loss = object_mask * box_loss_scale * 0.5 * K.square(raw_true_wh-raw_pred[...,2:4])
        # Object cells always contribute; no-object cells contribute only where
        # ignore_mask is 1. The two terms must be summed.
        objectness_bce = K.binary_crossentropy(object_mask, raw_pred[...,4:5], from_logits=True)
        confidence_loss = object_mask * objectness_bce + \
            (1-object_mask) * objectness_bce * ignore_mask
        class_loss = object_mask * K.binary_crossentropy(true_class_probs, raw_pred[...,5:], from_logits=True)

        # Reduce each term to a scalar and divide by the batch size.
        xy_loss = K.sum(xy_loss) / mf
        wh_loss = K.sum(wh_loss) / mf
        confidence_loss = K.sum(confidence_loss) / mf
        class_loss = K.sum(class_loss) / mf
        loss += xy_loss + wh_loss + confidence_loss + class_loss
        if print_loss:
            loss = tf.Print(loss, [loss, xy_loss, wh_loss, confidence_loss, class_loss, K.sum(ignore_mask)], message='loss: ')
    return loss
|